Backend Server Update

2025-03-29 22:06:00 -04:00
parent e8a9207da4
commit 06fa7936a3
3 changed files with 360 additions and 284 deletions


@@ -5,6 +5,8 @@ import asyncio
 import torch
 import torchaudio
 import numpy as np
+import io
+import whisperx
 from io import BytesIO
 from typing import List, Dict, Any, Optional
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
@@ -13,6 +15,7 @@ from pydantic import BaseModel
 from generator import load_csm_1b, Segment
 import uvicorn
 import time
+import gc
 from collections import deque
 
 # Select device
@@ -25,6 +28,12 @@ print(f"Using device: {device}")
 # Initialize the model
 generator = load_csm_1b(device=device)
 
+# Initialize WhisperX for ASR
+print("Loading WhisperX model...")
+# Use a smaller model for faster response times
+asr_model = whisperx.load_model("medium", device, compute_type="float16")
+print("WhisperX model loaded!")
+
 app = FastAPI()
 
 # Add CORS middleware to allow cross-origin requests
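The compute type here is hard-coded to float16, which generally requires a CUDA device; on a CPU-only host, CTranslate2 (the backend WhisperX uses) will refuse it at load time. A minimal device-aware fallback, sketched under the assumption that the "medium" model from this commit is kept:

import torch
import whisperx

# float16 on GPU; int8 is the usual CTranslate2 fallback on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
asr_model = whisperx.load_model("medium", device, compute_type=compute_type)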
@@ -93,6 +102,68 @@ async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
     return f"data:audio/wav;base64,{audio_base64}"
 
+async def transcribe_audio(audio_tensor: torch.Tensor) -> str:
+    """Transcribe audio using WhisperX"""
+    try:
+        # Save the tensor to a temporary file
+        temp_file = BytesIO()
+        torchaudio.save(temp_file, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
+        temp_file.seek(0)
+
+        # Create a temporary file on disk (WhisperX requires a file path)
+        temp_path = "temp_audio.wav"
+        with open(temp_path, "wb") as f:
+            f.write(temp_file.read())
+
+        # Load and transcribe the audio
+        audio = whisperx.load_audio(temp_path)
+        result = asr_model.transcribe(audio, batch_size=16)
+
+        # Clean up
+        os.remove(temp_path)
+
+        # Get the transcription text
+        if result["segments"] and len(result["segments"]) > 0:
+            # Combine all segments
+            transcription = " ".join([segment["text"] for segment in result["segments"]])
+            print(f"Transcription: {transcription}")
+            return transcription.strip()
+        else:
+            return ""
+    except Exception as e:
+        print(f"Error in transcription: {str(e)}")
+        return ""
+async def generate_response(text: str, conversation_history: List[Segment]) -> str:
+    """Generate a contextual response based on the transcribed text"""
+    # Simple response logic - can be replaced with a more sophisticated LLM in the future
+    responses = {
+        "hello": "Hello there! How are you doing today?",
+        "how are you": "I'm doing well, thanks for asking! How about you?",
+        "what is your name": "I'm Sesame, your voice assistant. How can I help you?",
+        "bye": "Goodbye! It was nice chatting with you.",
+        "thank you": "You're welcome! Is there anything else I can help with?",
+        "weather": "I don't have real-time weather data, but I hope it's nice where you are!",
+        "help": "I can chat with you using natural voice. Just speak normally and I'll respond.",
+    }
+
+    text_lower = text.lower()
+
+    # Check for matching keywords
+    for key, response in responses.items():
+        if key in text_lower:
+            return response
+
+    # Default responses based on text length
+    if not text:
+        return "I didn't catch that. Could you please repeat?"
+    elif len(text) < 10:
+        return "Thanks for your message. Could you elaborate a bit more?"
+    else:
+        return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"
 @app.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket):
     await manager.connect(websocket)
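One quirk of the keyword matcher in generate_response above: `key in text_lower` is substring containment, so "help" also fires inside "helpless" and "bye" inside "goodbye" (harmless here, but only by luck). A stricter word-boundary variant, sketched against the same responses dict:

import re

def match_keyword(text_lower: str):
    # Whole-word/phrase matches only: "help" no longer fires on "helpless".
    for key, response in responses.items():
        if re.search(rf"\b{re.escape(key)}\b", text_lower):
            return response
    return None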
@@ -220,30 +291,55 @@ async def websocket_endpoint(websocket: WebSocket):
                         # User has stopped talking - process the collected audio
                         full_audio = torch.cat(streaming_buffer, dim=0)
 
-                        # Process with speech-to-text (you would need to implement this)
-                        # For now, just use a placeholder text
-                        text = f"User audio from speaker {speaker_id}"
+                        # Process with WhisperX speech-to-text
+                        transcribed_text = await transcribe_audio(full_audio)
+                        print(f"Detected end of speech, processing {len(streaming_buffer)} chunks")
+                        # Log the transcription
+                        print(f"Transcribed text: '{transcribed_text}'")
 
-                        # Add to conversation context
-                        context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))
-
-                        # Generate response
-                        response_text = "This is a response to what you just said"
-                        audio_tensor = generator.generate(
-                            text=response_text,
-                            speaker=1 if speaker_id == 0 else 0,  # Use opposite speaker
-                            context=context_segments,
-                            max_audio_length_ms=10_000,
-                        )
-
-                        # Convert audio to base64 and send back to client
-                        audio_base64 = await encode_audio_data(audio_tensor)
-                        await websocket.send_json({
-                            "type": "audio_response",
-                            "audio": audio_base64
-                        })
+                        if transcribed_text:
+                            user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
+                            context_segments.append(user_segment)
+
+                            # Generate a contextual response
+                            response_text = await generate_response(transcribed_text, context_segments)
+
+                            # Send the transcribed text to client
+                            await websocket.send_json({
+                                "type": "transcription",
+                                "text": transcribed_text
+                            })
+
+                            # Generate audio for the response
+                            audio_tensor = generator.generate(
+                                text=response_text,
+                                speaker=1 if speaker_id == 0 else 0,  # Use opposite speaker
+                                context=context_segments,
+                                max_audio_length_ms=10_000,
+                            )
+
+                            # Add response to context
+                            ai_segment = Segment(
+                                text=response_text,
+                                speaker=1 if speaker_id == 0 else 0,
+                                audio=audio_tensor
+                            )
+                            context_segments.append(ai_segment)
+
+                            # Convert audio to base64 and send back to client
+                            audio_base64 = await encode_audio_data(audio_tensor)
+                            await websocket.send_json({
+                                "type": "audio_response",
+                                "text": response_text,
+                                "audio": audio_base64
+                            })
+                        else:
+                            # If transcription failed, send a generic response
+                            await websocket.send_json({
+                                "type": "error",
+                                "message": "Sorry, I couldn't understand what you said. Could you try again?"
+                            })
 
                         # Clear buffer and reset silence detection
                         streaming_buffer = []
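After this change the client can receive three message types on the socket: "transcription" (the user's recognized text), "audio_response" (which now carries the reply "text" alongside the base64 WAV data URL), and "error". A hypothetical client loop using the third-party websockets package, just to illustrate the message shapes:

import asyncio
import json
import websockets

async def listen():
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if msg["type"] == "transcription":
                print("You said:", msg["text"])
            elif msg["type"] == "audio_response":
                # msg["audio"] is a "data:audio/wav;base64,..." string
                print("Assistant:", msg["text"])
            elif msg["type"] == "error":
                print("Server error:", msg["message"])

asyncio.run(listen())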
@@ -256,8 +352,19 @@ async def websocket_endpoint(websocket: WebSocket):
                     elif len(streaming_buffer) >= 30:  # ~6 seconds of audio at 5 chunks/sec
                         print("Buffer limit reached, processing audio")
                         full_audio = torch.cat(streaming_buffer, dim=0)
-                        text = f"Continued speech from speaker {speaker_id}"
-                        context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))
+
+                        # Process with WhisperX speech-to-text
+                        transcribed_text = await transcribe_audio(full_audio)
+                        if transcribed_text:
+                            context_segments.append(Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio))
+
+                            # Send the transcribed text to client
+                            await websocket.send_json({
+                                "type": "transcription",
+                                "text": transcribed_text + " (processing continued speech...)"
+                            })
+
                         streaming_buffer = []
                 except Exception as e:
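The 30-chunk cap only equals ~6 seconds if the client really sends 5 chunks per second; counting buffered samples keeps the cap stable for any chunk size. A sketch, assuming each buffered chunk is a 1-D tensor at generator.sample_rate:

# Duration-based cap instead of a chunk count
buffered_samples = sum(chunk.shape[0] for chunk in streaming_buffer)
if buffered_samples >= 6 * generator.sample_rate:  # ~6 seconds of audio
    full_audio = torch.cat(streaming_buffer, dim=0)
    # ...process exactly as above...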
@@ -269,11 +376,21 @@ async def websocket_endpoint(websocket: WebSocket):
             elif action == "stop_streaming":
                 is_streaming = False
-                if streaming_buffer:
+                if streaming_buffer and len(streaming_buffer) > 5:  # Only process if there's meaningful audio
                     # Process any remaining audio in the buffer
                     full_audio = torch.cat(streaming_buffer, dim=0)
-                    text = f"Final streaming audio from speaker {request.get('speaker', 0)}"
-                    context_segments.append(Segment(text=text, speaker=request.get("speaker", 0), audio=full_audio))
+
+                    # Process with WhisperX speech-to-text
+                    transcribed_text = await transcribe_audio(full_audio)
+                    if transcribed_text:
+                        context_segments.append(Segment(text=transcribed_text, speaker=request.get("speaker", 0), audio=full_audio))
+
+                        # Send the transcribed text to client
+                        await websocket.send_json({
+                            "type": "transcription",
+                            "text": transcribed_text
+                        })
+
                     streaming_buffer = []
 
                 await websocket.send_json({
@@ -286,12 +403,15 @@ async def websocket_endpoint(websocket: WebSocket):
         print("Client disconnected")
     except Exception as e:
         print(f"Error: {str(e)}")
-        await websocket.send_json({
-            "type": "error",
-            "message": str(e)
-        })
+        try:
+            await websocket.send_json({
+                "type": "error",
+                "message": str(e)
+            })
+        except:
+            pass
         manager.disconnect(websocket)
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="localhost", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=8000)