Demo Fixes 19

2025-03-30 08:59:26 -04:00
parent eff4b65c3b
commit 8695dd0297
2 changed files with 694 additions and 680 deletions

File diff suppressed because it is too large

@@ -14,8 +14,8 @@ import huggingface_hub
 from generator import load_csm_1b, Segment
 import threading
 import queue
-from flask import stream_with_context, Response
-import time
+import asyncio
+import json
 
 # Configure environment with longer timeouts
 os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600"  # 10 minutes timeout for downloads
@@ -26,7 +26,7 @@ os.makedirs("models", exist_ok=True)
 app = Flask(__name__)
 app.config['SECRET_KEY'] = 'your-secret-key'
-socketio = SocketIO(app, cors_allowed_origins="*")
+socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
 
 # Explicitly check for CUDA and print more detailed info
 print("\n=== CUDA Information ===")
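Note that this hunk only switches async_mode to 'eventlet'; nothing else about the eventlet setup appears in the diff. A minimal sketch of what is usually needed alongside this change (an assumption, since the top of the file is not shown here):

# Minimal sketch, an assumption not shown in this diff: async_mode='eventlet'
# requires the eventlet package (pip install eventlet), and it is commonly
# monkey-patched at the very top of the module so that standard-library
# sockets, threading and time cooperate with the eventlet event loop.
import eventlet
eventlet.monkey_patch()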
@@ -128,8 +128,7 @@ def load_models():
 
 # Store conversation context
 conversation_context = {}  # session_id -> context
-CHUNK_SIZE = 24000  # Number of audio samples per chunk (1 second at 24kHz)
-audio_stream_queues = {}  # session_id -> queue for audio chunks
+active_audio_streams = {}  # session_id -> stream status
 
 @app.route('/')
 def index():
@@ -143,9 +142,14 @@ def handle_connect():
         'speakers': [0, 1],  # 0 = user, 1 = bot
         'audio_buffer': deque(maxlen=10),  # Store recent audio chunks
         'is_speaking': False,
-        'silence_start': None
+        'last_activity': time.time(),
+        'active_session': True,
+        'transcription_buffer': []  # For real-time transcription
     }
-    emit('ready', {'message': 'Connection established'})
+    emit('ready', {
+        'message': 'Connection established',
+        'sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000
+    })
 
 @socketio.on('disconnect')
 def handle_disconnect():
@@ -154,56 +158,130 @@ def handle_disconnect():
     # Clean up resources
     if session_id in conversation_context:
+        conversation_context[session_id]['active_session'] = False
         del conversation_context[session_id]
-    if session_id in audio_stream_queues:
-        del audio_stream_queues[session_id]
+    if session_id in active_audio_streams:
+        active_audio_streams[session_id]['active'] = False
+        del active_audio_streams[session_id]
 
-@socketio.on('start_speaking')
-def handle_start_speaking():
-    if request.sid in conversation_context:
-        conversation_context[request.sid]['is_speaking'] = True
-        conversation_context[request.sid]['audio_buffer'].clear()
-        print(f"User {request.sid} started speaking")
-
-@socketio.on('audio_chunk')
-def handle_audio_chunk(data):
-    if request.sid not in conversation_context:
+@socketio.on('audio_stream')
+def handle_audio_stream(data):
+    """Handle incoming audio stream from client"""
+    session_id = request.sid
+
+    if session_id not in conversation_context:
         return
 
-    context = conversation_context[request.sid]
+    context = conversation_context[session_id]
+    context['last_activity'] = time.time()
 
-    # Decode audio data
-    audio_data = base64.b64decode(data['audio'])
-    audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
-    audio_tensor = torch.tensor(audio_numpy)
-
-    # Add to buffer
-    context['audio_buffer'].append(audio_tensor)
-
-    # Check for silence to detect end of speech
-    if context['is_speaking'] and is_silence(audio_tensor):
-        if context['silence_start'] is None:
-            context['silence_start'] = time.time()
-        elif time.time() - context['silence_start'] > 1.0:  # 1 second of silence
-            # Process the complete utterance
-            process_user_utterance(request.sid)
-    else:
-        context['silence_start'] = None
-
-@socketio.on('stop_speaking')
-def handle_stop_speaking():
-    if request.sid in conversation_context:
-        conversation_context[request.sid]['is_speaking'] = False
-        process_user_utterance(request.sid)
-        print(f"User {request.sid} stopped speaking")
-
-def is_silence(audio_tensor, threshold=0.02):
-    """Check if an audio chunk is silence based on amplitude threshold"""
-    return torch.mean(torch.abs(audio_tensor)) < threshold
+    # Process different stream events
+    if data.get('event') == 'start':
+        # Client is starting to send audio
+        context['is_speaking'] = True
+        context['audio_buffer'].clear()
+        context['transcription_buffer'] = []
+        print(f"User {session_id} started streaming audio")
+
+        # If AI was speaking, interrupt it
+        if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
+            active_audio_streams[session_id]['active'] = False
+            emit('ai_stream_interrupt', {}, room=session_id)
+
+    elif data.get('event') == 'data':
+        # Audio data received
+        if not context['is_speaking']:
+            return
+
+        # Decode audio chunk
+        try:
+            audio_data = base64.b64decode(data.get('audio', ''))
+            if not audio_data:
+                return
+
+            audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
+
+            # Apply a simple noise gate
+            if np.mean(np.abs(audio_numpy)) < 0.01:  # Very quiet
+                return
+
+            audio_tensor = torch.tensor(audio_numpy)
+
+            # Add to audio buffer
+            context['audio_buffer'].append(audio_tensor)
+
+            # Real-time transcription (periodic)
+            if len(context['audio_buffer']) % 3 == 0:  # Process every 3 chunks
+                threading.Thread(
+                    target=process_realtime_transcription,
+                    args=(session_id,),
+                    daemon=True
+                ).start()
+        except Exception as e:
+            print(f"Error processing audio chunk: {e}")
+
+    elif data.get('event') == 'end':
+        # Client has finished sending audio
+        context['is_speaking'] = False
+
+        if len(context['audio_buffer']) > 0:
+            # Process the complete utterance
+            threading.Thread(
+                target=process_complete_utterance,
+                args=(session_id,),
+                daemon=True
+            ).start()
+
+        print(f"User {session_id} stopped streaming audio")
+
+def process_realtime_transcription(session_id):
+    """Process incoming audio for real-time transcription"""
+    if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
+        return
+
+    context = conversation_context[session_id]
+
+    if not context['audio_buffer'] or not context['is_speaking']:
+        return
+
+    try:
+        # Combine current buffer for transcription
+        buffer_copy = list(context['audio_buffer'])
+        if not buffer_copy:
+            return
+
+        full_audio = torch.cat(buffer_copy, dim=0)
+
+        # Save audio to temporary WAV file for transcription
+        temp_audio_path = f"temp_rt_{session_id}.wav"
+        torchaudio.save(
+            temp_audio_path,
+            full_audio.unsqueeze(0),
+            44100  # Assuming 44.1kHz from client
+        )
+
+        # Transcribe with Whisper if available
+        if whisper_model is not None:
+            segments, _ = whisper_model.transcribe(temp_audio_path, beam_size=5)
+            text = " ".join([segment.text for segment in segments])
+
+            if text.strip():
+                context['transcription_buffer'].append(text)
+                # Send partial transcription to client
+                emit('partial_transcription', {'text': text}, room=session_id)
+    except Exception as e:
+        print(f"Error in realtime transcription: {e}")
+    finally:
+        # Clean up
+        if os.path.exists(temp_audio_path):
+            os.remove(temp_audio_path)
 
-def process_user_utterance(session_id):
+def process_complete_utterance(session_id):
     """Process completed user utterance, generate response and stream audio back"""
+    if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
+        return
+
     context = conversation_context[session_id]
 
     if not context['audio_buffer']:
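For reference, a hypothetical test client for the 'audio_stream' protocol added above, written with the python-socketio client package. The chunk size, the pacing, and the raw float32 encoding of each 'data' payload are assumptions on my part; the project's actual browser client is not shown in this diff.

# Hypothetical test client, assuming: pip install "python-socketio[client]",
# server on localhost:5000, and raw little-endian float32 samples base64-encoded
# into each 'data' event (matching np.frombuffer(..., dtype=np.float32) above).
import base64
import time

import numpy as np
import socketio

sio = socketio.Client()

@sio.on('ready')
def on_ready(data):
    print("connected, server sample rate:", data.get('sample_rate'))

@sio.on('partial_transcription')
def on_partial_transcription(data):
    print("partial transcription:", data['text'])

def send_utterance(samples, chunk_samples=4410):
    """Stream one utterance of float32 samples in roughly 100 ms chunks."""
    sio.emit('audio_stream', {'event': 'start'})
    for i in range(0, len(samples), chunk_samples):
        chunk = samples[i:i + chunk_samples].astype(np.float32)
        sio.emit('audio_stream', {
            'event': 'data',
            'audio': base64.b64encode(chunk.tobytes()).decode('utf-8'),
        })
        time.sleep(0.1)  # pace roughly like a live microphone
    sio.emit('audio_stream', {'event': 'end'})

if __name__ == '__main__':
    sio.connect('http://localhost:5000')
    # One second of a quiet 220 Hz tone, loud enough to pass the noise gate.
    t = np.linspace(0, 1.0, 44100, dtype=np.float32)
    send_utterance(0.1 * np.sin(2 * np.pi * 220 * t))
    sio.wait()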
@@ -212,8 +290,6 @@ def process_user_utterance(session_id):
     # Combine audio chunks
     full_audio = torch.cat(list(context['audio_buffer']), dim=0)
     context['audio_buffer'].clear()
-    context['is_speaking'] = False
-    context['silence_start'] = None
 
     # Save audio to temporary WAV file for transcription
     temp_audio_path = f"temp_audio_{session_id}.wav"
@@ -255,23 +331,23 @@ def process_user_utterance(session_id):
     # Generate and stream audio response if CSM is available
     if csm_generator is not None:
-        # Set up streaming queue for this session
-        if session_id not in audio_stream_queues:
-            audio_stream_queues[session_id] = queue.Queue()
-        else:
-            # Clear any existing items in the queue
-            while not audio_stream_queues[session_id].empty():
-                audio_stream_queues[session_id].get()
+        # Create stream state object
+        active_audio_streams[session_id] = {
+            'active': True,
+            'text': bot_response
+        }
 
-        # Start audio generation in a separate thread to not block the server
+        # Send initial response to prepare client
+        emit('ai_stream_start', {
+            'text': bot_response
+        }, room=session_id)
+
+        # Start audio generation in a separate thread
         threading.Thread(
-            target=generate_and_stream_audio,
+            target=generate_and_stream_audio_realtime,
             args=(bot_response, context['segments'], session_id),
             daemon=True
         ).start()
-
-        # Initial response with text
-        emit('start_streaming_response', {'text': bot_response}, room=session_id)
     else:
         # Send text-only response if audio generation isn't available
         emit('text_response', {'text': bot_response}, room=session_id)
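A hypothetical sketch of the receiving side for the ai_stream_start event added above and the ai_stream_data / ai_stream_end events that appear in the later hunks, again using the python-socketio client. Decoding the chunks as raw float32 at 24 kHz is an assumption: the code that fills audio_bytes is outside the hunks shown here, so the payload may actually be WAV-encoded.

# Hypothetical receive-side handlers for the same test client (or a fresh
# socketio.Client()). The real browser client would feed these chunks to the
# Web Audio API instead of collecting them in a list.
import base64

import numpy as np
import socketio

sio = socketio.Client()  # or reuse the client from the previous sketch
received_chunks = []

@sio.on('ai_stream_start')
def on_ai_stream_start(data):
    print("bot says:", data['text'])
    received_chunks.clear()

@sio.on('ai_stream_data')
def on_ai_stream_data(data):
    # Assumes raw float32 samples; adjust if the server sends WAV bytes.
    received_chunks.append(np.frombuffer(base64.b64decode(data['audio']), dtype=np.float32))

@sio.on('ai_stream_end')
def on_ai_stream_end(_data):
    audio = np.concatenate(received_chunks) if received_chunks else np.zeros(0, dtype=np.float32)
    print(f"received {audio.size} samples (~{audio.size / 24000:.1f}s at an assumed 24 kHz)")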
@@ -378,8 +454,11 @@ def fallback_response(user_text):
     else:
         return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available."
 
-def generate_audio_response(text, conversation_segments):
-    """Generate audio response using CSM"""
+def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
+    """Generate audio response using CSM and stream it in real-time to client"""
+    if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
+        return
+
     try:
         # Use the last few conversation segments as context
         context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
@@ -394,40 +473,23 @@ def generate_audio_response(text, conversation_segments):
             topk=50
         )
 
-        return audio
-    except Exception as e:
-        print(f"Error generating audio: {e}")
-        # Return silence as fallback
-        return torch.zeros(csm_generator.sample_rate * 3)  # 3 seconds of silence
-
-def generate_and_stream_audio(text, conversation_segments, session_id):
-    """Generate audio response using CSM and stream it in chunks"""
-    try:
-        # Use the last few conversation segments as context
-        context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
-
-        # Generate full audio for bot response
-        audio = csm_generator.generate(
-            text=text,
-            speaker=1,  # Bot is speaker 1
-            context=context_segments,
-            max_audio_length_ms=10000,  # 10 seconds max
-            temperature=0.9,
-            topk=50
-        )
-
         # Store the full audio for conversation history
         bot_segment = Segment(
             text=text,
             speaker=1,  # Bot is speaker 1
             audio=audio
         )
 
-        if session_id in conversation_context:
+        if session_id in conversation_context and conversation_context[session_id]['active_session']:
             conversation_context[session_id]['segments'].append(bot_segment)
 
-        # Split audio into chunks for streaming
-        chunk_size = CHUNK_SIZE
+        # Stream audio in small chunks for more responsive playback
+        chunk_size = 4800  # 200ms at 24kHz
         for i in range(0, len(audio), chunk_size):
+            if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
+                print("Audio streaming interrupted or session ended")
+                break
+
             chunk = audio[i:i+chunk_size]
 
             # Convert audio chunk to base64 for streaming
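The new chunk size of 4800 samples corresponds to 200 ms at the CSM generator's 24 kHz output rate. A quick back-of-the-envelope check of the per-chunk payload, assuming the chunk is serialized as raw float32 before base64 encoding:

# Sanity check of the streaming chunk size (assumes 24 kHz mono float32 output).
sample_rate = 24000                              # samples per second
chunk_samples = 4800                             # samples per chunk
chunk_ms = 1000 * chunk_samples / sample_rate    # 200.0 ms
raw_bytes = chunk_samples * 4                    # float32 -> 19,200 bytes
b64_bytes = 4 * ((raw_bytes + 2) // 3)           # ~25,600 bytes after base64
print(chunk_ms, raw_bytes, b64_bytes)

Combined with the 0.15 s sleep added per chunk in the next hunk, delivery stays slightly ahead of real-time playback for each 200 ms of audio.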
@@ -436,61 +498,33 @@ def generate_and_stream_audio(text, conversation_segments, session_id):
             audio_bytes.seek(0)
             audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')
 
-            # Send the chunk to the client
-            if session_id in audio_stream_queues:
-                audio_stream_queues[session_id].put({
-                    'audio': audio_b64,
-                    'is_last': i + chunk_size >= len(audio)
-                })
-            else:
-                # Session was disconnected before we finished generating
-                break
+            # Send chunk to client
+            socketio.emit('ai_stream_data', {
+                'audio': audio_b64,
+                'is_last': i + chunk_size >= len(audio)
+            }, room=session_id)
 
-        # Signal the end of streaming if queue still exists
-        if session_id in audio_stream_queues:
-            # Add an empty chunk as a sentinel to signal end of streaming
-            audio_stream_queues[session_id].put(None)
+            # Simulate real-time speech by adding a small delay
+            # Remove this in production for faster response
+            time.sleep(0.15)  # Slight delay for more natural timing
+
+        # Signal end of stream
+        if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
+            socketio.emit('ai_stream_end', {}, room=session_id)
+            active_audio_streams[session_id]['active'] = False
 
     except Exception as e:
         print(f"Error generating or streaming audio: {e}")
         # Send error message to client
-        if session_id in conversation_context:
+        if session_id in conversation_context and conversation_context[session_id]['active_session']:
             socketio.emit('error', {
                 'message': f'Error generating audio: {str(e)}'
             }, room=session_id)
 
-            # Send a final message to unblock the client
-            if session_id in audio_stream_queues:
-                audio_stream_queues[session_id].put(None)
+            # Signal stream end to unblock client
+            socketio.emit('ai_stream_end', {}, room=session_id)
+            if session_id in active_audio_streams:
+                active_audio_streams[session_id]['active'] = False
 
-@socketio.on('request_audio_chunk')
-def handle_request_audio_chunk():
-    """Send the next audio chunk in the queue to the client"""
-    session_id = request.sid
-
-    if session_id not in audio_stream_queues:
-        emit('error', {'message': 'No audio stream available'})
-        return
-
-    # Get the next chunk or wait for it to be available
-    try:
-        if not audio_stream_queues[session_id].empty():
-            chunk = audio_stream_queues[session_id].get(block=False)
-
-            # If chunk is None, we're done streaming
-            if chunk is None:
-                emit('end_streaming')
-                # Clean up the queue
-                if session_id in audio_stream_queues:
-                    del audio_stream_queues[session_id]
-            else:
-                emit('audio_chunk', chunk)
-        else:
-            # If the queue is empty but we're still generating, tell client to wait
-            emit('wait_for_chunk')
-    except Exception as e:
-        print(f"Error sending audio chunk: {e}")
-        emit('error', {'message': f'Error streaming audio: {str(e)}'})
-
 if __name__ == '__main__':
     # Ensure the existing index.html file is in the correct location
@@ -500,10 +534,10 @@ if __name__ == '__main__':
     if os.path.exists('index.html') and not os.path.exists('templates/index.html'):
         os.rename('index.html', 'templates/index.html')
 
-    # Load models asynchronously before starting the server
+    # Load models before starting the server
     print("Starting model loading...")
     load_models()
 
-    # Start the server
+    # Start the server with eventlet for better WebSocket performance
     print("Starting Flask SocketIO server...")
     socketio.run(app, host='0.0.0.0', port=5000, debug=False)