Demo Fixes 14

This commit is contained in:
2025-03-30 08:36:50 -04:00
parent a55b3f52a4
commit 12383d5e8b
2 changed files with 475 additions and 33 deletions

View File

@@ -12,6 +12,10 @@ from collections import deque
import requests
import huggingface_hub
from generator import load_csm_1b, Segment
import threading
import queue
from flask import stream_with_context, Response
import time
# Configure environment with longer timeouts
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600" # 10 minutes timeout for downloads
@@ -124,6 +128,8 @@ def load_models():
# Per-session server state, keyed by the socket session id (request.sid).
# Entries in both dicts are dropped again when the client disconnects.
conversation_context = {} # session_id -> context dict (holds at least 'audio_buffer' and 'segments')
CHUNK_SIZE = 24000 # Number of audio samples per streamed chunk (1 second at 24kHz)
audio_stream_queues = {} # session_id -> queue.Queue of encoded audio chunks; a None item marks end of stream
@app.route('/')
def index():
@@ -144,8 +150,14 @@ def handle_connect():
@socketio.on('disconnect')
def handle_disconnect():
    """Release all per-session state when a client disconnects.

    Removes the session's conversation context and its audio streaming
    queue, if present. Background audio threads notice the missing queue
    and stop streaming for this session.
    """
    session_id = request.sid
    print(f"Client disconnected: {session_id}")

    # pop(..., None) is a single step, so it cannot raise KeyError when the
    # entry is already gone or is removed concurrently by another thread.
    conversation_context.pop(session_id, None)
    audio_stream_queues.pop(session_id, None)
@socketio.on('start_speaking')
def handle_start_speaking():
@@ -191,7 +203,7 @@ def is_silence(audio_tensor, threshold=0.02):
return torch.mean(torch.abs(audio_tensor)) < threshold
def process_user_utterance(session_id):
"""Process completed user utterance, generate response and send audio back"""
"""Process completed user utterance, generate response and stream audio back"""
context = conversation_context[session_id]
if not context['audio_buffer']:
@@ -234,37 +246,32 @@ def process_user_utterance(session_id):
)
context['segments'].append(user_segment)
# Generate bot response
# Generate bot response text
bot_response = generate_llm_response(user_text, context['segments'])
print(f"Bot response: {bot_response}")
# Send transcribed text to client
emit('transcription', {'text': user_text}, room=session_id)
# Generate and send audio response if CSM is available
# Generate and stream audio response if CSM is available
if csm_generator is not None:
# Convert to audio using CSM
bot_audio = generate_audio_response(bot_response, context['segments'])
# Set up streaming queue for this session
if session_id not in audio_stream_queues:
audio_stream_queues[session_id] = queue.Queue()
else:
# Clear any existing items in the queue
while not audio_stream_queues[session_id].empty():
audio_stream_queues[session_id].get()
# Convert audio to base64 for sending over websocket
audio_bytes = io.BytesIO()
torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav")
audio_bytes.seek(0)
audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')
# Start audio generation in a separate thread to not block the server
threading.Thread(
target=generate_and_stream_audio,
args=(bot_response, context['segments'], session_id),
daemon=True
).start()
# Add bot response to conversation history
bot_segment = Segment(
text=bot_response,
speaker=1, # Bot is speaker 1
audio=bot_audio
)
context['segments'].append(bot_segment)
# Send audio response to client
emit('audio_response', {
'audio': audio_b64,
'text': bot_response
}, room=session_id)
# Initial response with text
emit('start_streaming_response', {'text': bot_response}, room=session_id)
else:
# Send text-only response if audio generation isn't available
emit('text_response', {'text': bot_response}, room=session_id)
@@ -391,6 +398,98 @@ def generate_audio_response(text, conversation_segments):
# Return silence as fallback
return torch.zeros(csm_generator.sample_rate * 3) # 3 seconds of silence
def generate_and_stream_audio(text, conversation_segments, session_id):
    """Generate the bot's audio reply with CSM and queue it in chunks.

    Intended to run in a background thread. The full reply is generated
    first, recorded in the session's conversation history, then split into
    CHUNK_SIZE-sample pieces. Each piece is WAV-encoded, base64-encoded and
    put on the session's queue in ``audio_stream_queues`` as
    ``{'audio': <base64 str>, 'is_last': <bool>}``. A final ``None`` item
    marks the end of the stream.

    Args:
        text: Text the bot should speak.
        conversation_segments: Prior conversation segments; the last four
            are passed to the generator as context.
        session_id: Socket session id that owns the context and the queue.

    Returns:
        None. Errors are printed and reported to the client through an
        'error' event; a ``None`` sentinel is still queued so the client is
        not left waiting.
    """
    try:
        # A slice handles short histories by itself and gives the generator
        # its own list, independent of later appends to the shared history.
        context_segments = conversation_segments[-4:]

        # Generate full audio for bot response
        audio = csm_generator.generate(
            text=text,
            speaker=1,  # Bot is speaker 1
            context=context_segments,
            max_audio_length_ms=10000,  # 10 seconds max
            temperature=0.9,
            topk=50
        )

        # Store the full audio for conversation history
        bot_segment = Segment(
            text=text,
            speaker=1,  # Bot is speaker 1
            audio=audio
        )
        # Single lookup: the disconnect handler may delete the entry at any
        # moment, so "in" followed by indexing could raise KeyError.
        session_context = conversation_context.get(session_id)
        if session_context is not None:
            session_context['segments'].append(bot_segment)

        sample_rate = csm_generator.sample_rate
        total_samples = len(audio)

        # Split audio into chunks for streaming
        for start in range(0, total_samples, CHUNK_SIZE):
            stream_queue = audio_stream_queues.get(session_id)
            if stream_queue is None:
                # Session was disconnected before we finished generating;
                # skip encoding chunks nobody will receive.
                break

            chunk = audio[start:start + CHUNK_SIZE]

            # Convert audio chunk to base64 for streaming
            wav_buffer = io.BytesIO()
            torchaudio.save(
                wav_buffer,
                chunk.unsqueeze(0).cpu(),
                sample_rate,
                format="wav"
            )
            audio_b64 = base64.b64encode(wav_buffer.getvalue()).decode('utf-8')

            stream_queue.put({
                'audio': audio_b64,
                'is_last': start + CHUNK_SIZE >= total_samples
            })

        # Signal the end of streaming if the queue still exists
        stream_queue = audio_stream_queues.get(session_id)
        if stream_queue is not None:
            stream_queue.put(None)

    except Exception as e:
        # Top-level boundary of the worker thread: report and recover
        # rather than letting the thread die silently.
        print(f"Error generating or streaming audio: {e}")

        # Send error message to client
        if session_id in conversation_context:
            socketio.emit('error', {
                'message': f'Error generating audio: {str(e)}'
            }, room=session_id)

        # Send a final message to unblock the client. Using .get() keeps a
        # concurrent disconnect from raising inside this handler.
        stream_queue = audio_stream_queues.get(session_id)
        if stream_queue is not None:
            stream_queue.put(None)
@socketio.on('request_audio_chunk')
def handle_request_audio_chunk():
    """Send the next queued audio chunk to the requesting client.

    Emits exactly one event per request:
      * 'audio_chunk'    - the next chunk dict taken from the session queue
      * 'wait_for_chunk' - the queue is currently empty (still generating)
      * 'end_streaming'  - the None sentinel was reached; the queue is dropped
      * 'error'          - no stream exists for this session, or sending failed
    """
    session_id = request.sid

    # Single lookup: the queue can be removed by the disconnect handler
    # between an "in" test and a later index.
    chunk_queue = audio_stream_queues.get(session_id)
    if chunk_queue is None:
        emit('error', {'message': 'No audio stream available'})
        return

    try:
        try:
            chunk = chunk_queue.get_nowait()
        except queue.Empty:
            # If the queue is empty but we're still generating, tell client to wait
            emit('wait_for_chunk')
            return

        if chunk is None:
            # None is the end-of-stream sentinel
            emit('end_streaming')
            # Clean up the queue; pop() tolerates it being gone already
            audio_stream_queues.pop(session_id, None)
        else:
            emit('audio_chunk', chunk)
    except Exception as e:
        # Handler boundary: report the failure to the client instead of raising.
        print(f"Error sending audio chunk: {e}")
        emit('error', {'message': f'Error streaming audio: {str(e)}'})
if __name__ == '__main__':
# Ensure the existing index.html file is in the correct location
if not os.path.exists('templates'):