Demo Fixes 19

2025-03-30 08:59:26 -04:00
parent eff4b65c3b
commit 8695dd0297
2 changed files with 694 additions and 680 deletions

File diff suppressed because it is too large

@@ -14,8 +14,8 @@ import huggingface_hub
 from generator import load_csm_1b, Segment
 import threading
-import queue
-from flask import stream_with_context, Response
+import time
+import asyncio
 import json

 # Configure environment with longer timeouts
 os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600"  # 10 minutes timeout for downloads
@@ -26,7 +26,7 @@ os.makedirs("models", exist_ok=True)
 app = Flask(__name__)
 app.config['SECRET_KEY'] = 'your-secret-key'
-socketio = SocketIO(app, cors_allowed_origins="*")
+socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')

 # Explicitly check for CUDA and print more detailed info
 print("\n=== CUDA Information ===")
@@ -128,8 +128,7 @@ def load_models():
 # Store conversation context
 conversation_context = {}  # session_id -> context
-CHUNK_SIZE = 24000  # Number of audio samples per chunk (1 second at 24kHz)
-audio_stream_queues = {}  # session_id -> queue for audio chunks
+active_audio_streams = {}  # session_id -> stream status

 @app.route('/')
 def index():
@@ -143,9 +142,14 @@ def handle_connect():
         'speakers': [0, 1],  # 0 = user, 1 = bot
         'audio_buffer': deque(maxlen=10),  # Store recent audio chunks
         'is_speaking': False,
-        'silence_start': None
+        'last_activity': time.time(),
+        'active_session': True,
+        'transcription_buffer': []  # For real-time transcription
     }
-    emit('ready', {'message': 'Connection established'})
+    emit('ready', {
+        'message': 'Connection established',
+        'sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000
+    })
@socketio.on('disconnect')
def handle_disconnect():
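The enriched 'ready' payload lets the client configure playback before any audio arrives. A minimal sketch of the receiving side, assuming the python-socketio client package (the real demo client lives in the suppressed index.html diff):

import socketio

client = socketio.Client()

@client.on('ready')
def on_ready(data):
    # sample_rate falls back to 24000 Hz when the CSM generator failed to load
    print(data['message'], '- playback rate:', data.get('sample_rate', 24000))

client.connect('http://localhost:5000')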
@@ -154,56 +158,130 @@ def handle_disconnect():
     # Clean up resources
     if session_id in conversation_context:
+        conversation_context[session_id]['active_session'] = False
         del conversation_context[session_id]
-    if session_id in audio_stream_queues:
-        del audio_stream_queues[session_id]
+    if session_id in active_audio_streams:
+        active_audio_streams[session_id]['active'] = False
+        del active_audio_streams[session_id]

-@socketio.on('start_speaking')
-def handle_start_speaking():
-    if request.sid in conversation_context:
-        conversation_context[request.sid]['is_speaking'] = True
-        conversation_context[request.sid]['audio_buffer'].clear()
-        print(f"User {request.sid} started speaking")

-@socketio.on('audio_chunk')
-def handle_audio_chunk(data):
-    if request.sid not in conversation_context:
+@socketio.on('audio_stream')
+def handle_audio_stream(data):
+    """Handle incoming audio stream from client"""
+    session_id = request.sid
+    if session_id not in conversation_context:
         return

-    context = conversation_context[request.sid]
+    context = conversation_context[session_id]
+    context['last_activity'] = time.time()

-    # Decode audio data
-    audio_data = base64.b64decode(data['audio'])
-    audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
-    audio_tensor = torch.tensor(audio_numpy)
+    # Process different stream events
+    if data.get('event') == 'start':
+        # Client is starting to send audio
+        context['is_speaking'] = True
+        context['audio_buffer'].clear()
+        context['transcription_buffer'] = []
+        print(f"User {session_id} started streaming audio")
+        # If AI was speaking, interrupt it
+        if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
+            active_audio_streams[session_id]['active'] = False
+            emit('ai_stream_interrupt', {}, room=session_id)

-    # Add to buffer
-    context['audio_buffer'].append(audio_tensor)
+    elif data.get('event') == 'data':
+        # Audio data received
+        if not context['is_speaking']:
+            return
+        # Decode audio chunk
+        try:
+            audio_data = base64.b64decode(data.get('audio', ''))
+            if not audio_data:
+                return
+            audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
+            # Apply a simple noise gate
+            if np.mean(np.abs(audio_numpy)) < 0.01:  # Very quiet
+                return
+            audio_tensor = torch.tensor(audio_numpy)
+            # Add to audio buffer
+            context['audio_buffer'].append(audio_tensor)
+            # Real-time transcription (periodic)
+            if len(context['audio_buffer']) % 3 == 0:  # Process every 3 chunks
+                threading.Thread(
+                    target=process_realtime_transcription,
+                    args=(session_id,),
+                    daemon=True
+                ).start()
+        except Exception as e:
+            print(f"Error processing audio chunk: {e}")

-    # Check for silence to detect end of speech
-    if context['is_speaking'] and is_silence(audio_tensor):
-        if context['silence_start'] is None:
-            context['silence_start'] = time.time()
-        elif time.time() - context['silence_start'] > 1.0:  # 1 second of silence
+    elif data.get('event') == 'end':
+        # Client has finished sending audio
+        context['is_speaking'] = False
+        if len(context['audio_buffer']) > 0:
+            # Process the complete utterance
-            process_user_utterance(request.sid)
-        else:
-            context['silence_start'] = None
+            threading.Thread(
+                target=process_complete_utterance,
+                args=(session_id,),
+                daemon=True
+            ).start()
+        print(f"User {session_id} stopped streaming audio")
-@socketio.on('stop_speaking')
-def handle_stop_speaking():
-    if request.sid in conversation_context:
-        conversation_context[request.sid]['is_speaking'] = False
-        process_user_utterance(request.sid)
-        print(f"User {request.sid} stopped speaking")

+def process_realtime_transcription(session_id):
+    """Process incoming audio for real-time transcription"""
+    if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
+        return
+    context = conversation_context[session_id]
+    if not context['audio_buffer'] or not context['is_speaking']:
+        return
+    try:
+        # Combine current buffer for transcription
+        buffer_copy = list(context['audio_buffer'])
+        if not buffer_copy:
+            return
+        full_audio = torch.cat(buffer_copy, dim=0)

+        # Save audio to temporary WAV file for transcription
+        temp_audio_path = f"temp_rt_{session_id}.wav"
+        torchaudio.save(
+            temp_audio_path,
+            full_audio.unsqueeze(0),
+            44100  # Assuming 44.1kHz from client
+        )

+        # Transcribe with Whisper if available
+        if whisper_model is not None:
+            segments, _ = whisper_model.transcribe(temp_audio_path, beam_size=5)
+            text = " ".join([segment.text for segment in segments])
+            if text.strip():
+                context['transcription_buffer'].append(text)
+                # Send partial transcription to client
+                emit('partial_transcription', {'text': text}, room=session_id)
+    except Exception as e:
+        print(f"Error in realtime transcription: {e}")
+    finally:
+        # Clean up
+        if os.path.exists(temp_audio_path):
+            os.remove(temp_audio_path)
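Two hazards in this new helper are worth flagging: it runs on a plain thread, where Flask-SocketIO's request-bound emit() generally has to be replaced by socketio.emit(..., room=...), and the fixed temp_rt_{session_id}.wav path can collide when overlapping transcription threads fire for one session. A hedged sketch of both fixes, reusing the surrounding names and not part of this commit:

import tempfile

def process_realtime_transcription_safe(session_id):
    """Variant sketch: unique temp file plus context-free emit (assumptions above)."""
    context = conversation_context.get(session_id)
    if not context or not context['is_speaking'] or not context['audio_buffer']:
        return
    full_audio = torch.cat(list(context['audio_buffer']), dim=0)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        temp_audio_path = tmp.name  # unique per call, no cross-thread collision
    try:
        torchaudio.save(temp_audio_path, full_audio.unsqueeze(0), 44100)
        if whisper_model is not None:
            segments, _ = whisper_model.transcribe(temp_audio_path, beam_size=5)
            text = " ".join(segment.text for segment in segments)
            if text.strip():
                # socketio.emit works outside a request context, unlike emit()
                socketio.emit('partial_transcription', {'text': text}, room=session_id)
    finally:
        os.remove(temp_audio_path)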
def is_silence(audio_tensor, threshold=0.02):
    """Check if an audio chunk is silence based on amplitude threshold"""
    return torch.mean(torch.abs(audio_tensor)) < threshold
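For context, the two thresholds differ by design: the inline noise gate in the 'data' branch drops chunks whose mean absolute amplitude falls below 0.01, while is_silence() uses 0.02. A quick illustration with synthetic levels:

import numpy as np
import torch

room_noise = (np.random.randn(2400) * 0.005).astype(np.float32)  # very quiet
speech = (np.random.randn(2400) * 0.1).astype(np.float32)        # speech-level

print(np.mean(np.abs(room_noise)) < 0.01)        # True: gated out before buffering
print(bool(is_silence(torch.tensor(speech))))    # False: loud enough to keep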
-def process_user_utterance(session_id):
+def process_complete_utterance(session_id):
     """Process completed user utterance, generate response and stream audio back"""
     if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
         return

     context = conversation_context[session_id]
     if not context['audio_buffer']:
@@ -212,8 +290,6 @@ def process_user_utterance(session_id):
     # Combine audio chunks
     full_audio = torch.cat(list(context['audio_buffer']), dim=0)
     context['audio_buffer'].clear()
-    context['is_speaking'] = False
-    context['silence_start'] = None

     # Save audio to temporary WAV file for transcription
     temp_audio_path = f"temp_audio_{session_id}.wav"
@@ -255,23 +331,23 @@ def process_user_utterance(session_id):
     # Generate and stream audio response if CSM is available
     if csm_generator is not None:
-        # Set up streaming queue for this session
-        if session_id not in audio_stream_queues:
-            audio_stream_queues[session_id] = queue.Queue()
-        else:
-            # Clear any existing items in the queue
-            while not audio_stream_queues[session_id].empty():
-                audio_stream_queues[session_id].get()
+        # Create stream state object
+        active_audio_streams[session_id] = {
+            'active': True,
+            'text': bot_response
+        }

-        # Start audio generation in a separate thread to not block the server
+        # Send initial response to prepare client
+        emit('ai_stream_start', {
+            'text': bot_response
+        }, room=session_id)

+        # Start audio generation in a separate thread
         threading.Thread(
-            target=generate_and_stream_audio,
+            target=generate_and_stream_audio_realtime,
             args=(bot_response, context['segments'], session_id),
             daemon=True
         ).start()

-        # Initial response with text
-        emit('start_streaming_response', {'text': bot_response}, room=session_id)
     else:
         # Send text-only response if audio generation isn't available
         emit('text_response', {'text': bot_response}, room=session_id)
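One judgment call above: the response generator is launched with a raw threading.Thread, which under async_mode='eventlet' relies on monkey-patching to cooperate with the event loop. Flask-SocketIO's start_background_task is the portable spelling, since it picks the primitive matching the configured async mode. A sketch of the substitution, with the same arguments as above:

socketio.start_background_task(
    generate_and_stream_audio_realtime,
    bot_response,
    context['segments'],
    session_id,
)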
@@ -378,8 +454,11 @@ def fallback_response(user_text):
     else:
         return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available."

-def generate_audio_response(text, conversation_segments):
-    """Generate audio response using CSM"""
+def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
+    """Generate audio response using CSM and stream it in real-time to client"""
+    if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
+        return
     try:
         # Use the last few conversation segments as context
         context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
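A small simplification is available here: Python slices already clamp at the sequence boundary, so the length check is redundant.

# Equivalent to the conditional above: returns the whole list
# whenever it holds four or fewer segments.
context_segments = conversation_segments[-4:]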
@@ -394,40 +473,23 @@ def generate_audio_response(text, conversation_segments):
             topk=50
         )
-        return audio
-    except Exception as e:
-        print(f"Error generating audio: {e}")
-        # Return silence as fallback
-        return torch.zeros(csm_generator.sample_rate * 3)  # 3 seconds of silence

-def generate_and_stream_audio(text, conversation_segments, session_id):
-    """Generate audio response using CSM and stream it in chunks"""
-    try:
-        # Use the last few conversation segments as context
-        context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
-        # Generate full audio for bot response
-        audio = csm_generator.generate(
-            text=text,
-            speaker=1,  # Bot is speaker 1
-            context=context_segments,
-            max_audio_length_ms=10000,  # 10 seconds max
-            temperature=0.9,
-            topk=50
-        )

         # Store the full audio for conversation history
         bot_segment = Segment(
             text=text,
             speaker=1,  # Bot is speaker 1
             audio=audio
         )
-        if session_id in conversation_context:
+        if session_id in conversation_context and conversation_context[session_id]['active_session']:
             conversation_context[session_id]['segments'].append(bot_segment)

-        # Split audio into chunks for streaming
-        chunk_size = CHUNK_SIZE
+        # Stream audio in small chunks for more responsive playback
+        chunk_size = 4800  # 200ms at 24kHz
         for i in range(0, len(audio), chunk_size):
+            if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
+                print("Audio streaming interrupted or session ended")
+                break
             chunk = audio[i:i+chunk_size]

             # Convert audio chunk to base64 for streaming
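The new chunk size trades latency for overhead: at CSM's 24 kHz output rate, 4800 samples is 200 ms of audio per emit, and with the 0.15 s sleep in the loop below, audio is produced slightly faster than real time, so the client buffer slowly grows rather than starving. The arithmetic:

sample_rate = 24000          # CSM output rate (Hz)
chunk_size = 4800            # samples per emitted chunk
print(1000 * chunk_size / sample_rate)   # 200.0 ms of audio per chunk
# emitted every ~150 ms -> roughly 50 ms of audio banked per chunk client-side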
@@ -436,61 +498,33 @@ def generate_and_stream_audio(text, conversation_segments, session_id):
             audio_bytes.seek(0)
             audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')

-            # Send the chunk to the client
-            if session_id in audio_stream_queues:
-                audio_stream_queues[session_id].put({
-                    'audio': audio_b64,
-                    'is_last': i + chunk_size >= len(audio)
-                })
-            else:
-                # Session was disconnected before we finished generating
-                break

-        # Signal the end of streaming if queue still exists
-        if session_id in audio_stream_queues:
-            # Add an empty chunk as a sentinel to signal end of streaming
-            audio_stream_queues[session_id].put(None)
+            # Send chunk to client
+            socketio.emit('ai_stream_data', {
+                'audio': audio_b64,
+                'is_last': i + chunk_size >= len(audio)
+            }, room=session_id)

+            # Simulate real-time speech by adding a small delay
+            # Remove this in production for faster response
+            time.sleep(0.15)  # Slight delay for more natural timing

+        # Signal end of stream
+        if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
+            socketio.emit('ai_stream_end', {}, room=session_id)
+            active_audio_streams[session_id]['active'] = False

     except Exception as e:
         print(f"Error generating or streaming audio: {e}")
         # Send error message to client
-        if session_id in conversation_context:
+        if session_id in conversation_context and conversation_context[session_id]['active_session']:
             socketio.emit('error', {
                 'message': f'Error generating audio: {str(e)}'
             }, room=session_id)

-        # Send a final message to unblock the client
-        if session_id in audio_stream_queues:
-            audio_stream_queues[session_id].put(None)

-@socketio.on('request_audio_chunk')
-def handle_request_audio_chunk():
-    """Send the next audio chunk in the queue to the client"""
-    session_id = request.sid
-    if session_id not in audio_stream_queues:
-        emit('error', {'message': 'No audio stream available'})
-        return

-    # Get the next chunk or wait for it to be available
-    try:
-        if not audio_stream_queues[session_id].empty():
-            chunk = audio_stream_queues[session_id].get(block=False)

-            # If chunk is None, we're done streaming
-            if chunk is None:
-                emit('end_streaming')
-                # Clean up the queue
-                if session_id in audio_stream_queues:
-                    del audio_stream_queues[session_id]
-            else:
-                emit('audio_chunk', chunk)
-        else:
-            # If the queue is empty but we're still generating, tell client to wait
-            emit('wait_for_chunk')
-    except Exception as e:
-        print(f"Error sending audio chunk: {e}")
-        emit('error', {'message': f'Error streaming audio: {str(e)}'})

+        # Signal stream end to unblock client
+        socketio.emit('ai_stream_end', {}, room=session_id)
+        if session_id in active_audio_streams:
+            active_audio_streams[session_id]['active'] = False
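With the pull-based request_audio_chunk handler gone, the client now just listens for the pushed stream events: ai_stream_start, ai_stream_data, ai_stream_interrupt, and ai_stream_end. A hedged Python sketch of a consumer, reusing the client from the earlier sketch; note the chunk payload format depends on how audio_bytes is filled just above this hunk (apparently encoded via torchaudio rather than raw PCM), so the bytes are handed off opaquely, and playback_queue is a hypothetical stand-in for a real audio output:

import base64
import queue

playback_queue = queue.Queue()  # hypothetical hand-off to an audio player

@client.on('ai_stream_start')
def on_start(data):
    print('AI:', data['text'])

@client.on('ai_stream_data')
def on_data(data):
    playback_queue.put(base64.b64decode(data['audio']))
    if data['is_last']:
        print('final chunk queued')

@client.on('ai_stream_end')
def on_end():
    print('stream complete')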
if __name__ == '__main__':
    # Ensure the existing index.html file is in the correct location
@@ -500,10 +534,10 @@ if __name__ == '__main__':
     if os.path.exists('index.html') and not os.path.exists('templates/index.html'):
         os.rename('index.html', 'templates/index.html')

-    # Load models asynchronously before starting the server
+    # Load models before starting the server
     print("Starting model loading...")
     load_models()

-    # Start the server
+    # Start the server with eventlet for better WebSocket performance
     print("Starting Flask SocketIO server...")
     socketio.run(app, host='0.0.0.0', port=5000, debug=False)