Demo Fixes 19
Backend/index.html — 1060 changed lines; file diff suppressed because it is too large.

@@ -14,8 +14,8 @@ import huggingface_hub
 from generator import load_csm_1b, Segment
 import threading
 import queue
-from flask import stream_with_context, Response
-import time
+import asyncio
+import json
 
 # Configure environment with longer timeouts
 os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600"  # 10 minutes timeout for downloads
@@ -26,7 +26,7 @@ os.makedirs("models", exist_ok=True)
 
 app = Flask(__name__)
 app.config['SECRET_KEY'] = 'your-secret-key'
-socketio = SocketIO(app, cors_allowed_origins="*")
+socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
 
 # Explicitly check for CUDA and print more detailed info
 print("\n=== CUDA Information ===")
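Note on the async_mode='eventlet' change above: Flask-SocketIO only honors this mode if the eventlet package is installed, and eventlet generally needs to monkey-patch the standard library before the rest of the imports when the app also uses blocking calls or threads (as this file does). Whether the project already does that is not visible in this diff; a minimal sketch of the assumed setup:

    # Hypothetical top-of-file setup for eventlet mode -- not part of this commit.
    import eventlet
    eventlet.monkey_patch()  # patch sockets/threading before anything else is imported

    from flask import Flask
    from flask_socketio import SocketIO

    app = Flask(__name__)
    socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')

With monkey patching in place, the threading.Thread workers used later in this diff run as eventlet green threads rather than OS threads.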
@@ -128,8 +128,7 @@ def load_models():
 
 # Store conversation context
 conversation_context = {}  # session_id -> context
-CHUNK_SIZE = 24000  # Number of audio samples per chunk (1 second at 24kHz)
-audio_stream_queues = {}  # session_id -> queue for audio chunks
+active_audio_streams = {}  # session_id -> stream status
 
 @app.route('/')
 def index():
@@ -143,9 +142,14 @@ def handle_connect():
         'speakers': [0, 1],  # 0 = user, 1 = bot
         'audio_buffer': deque(maxlen=10),  # Store recent audio chunks
         'is_speaking': False,
-        'silence_start': None
+        'last_activity': time.time(),
+        'active_session': True,
+        'transcription_buffer': []  # For real-time transcription
     }
-    emit('ready', {'message': 'Connection established'})
+    emit('ready', {
+        'message': 'Connection established',
+        'sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000
+    })
 
 @socketio.on('disconnect')
 def handle_disconnect():
@@ -154,56 +158,130 @@ def handle_disconnect():
 
     # Clean up resources
     if session_id in conversation_context:
+        conversation_context[session_id]['active_session'] = False
         del conversation_context[session_id]
 
-    if session_id in audio_stream_queues:
-        del audio_stream_queues[session_id]
-
-@socketio.on('start_speaking')
-def handle_start_speaking():
-    if request.sid in conversation_context:
-        conversation_context[request.sid]['is_speaking'] = True
-        conversation_context[request.sid]['audio_buffer'].clear()
-        print(f"User {request.sid} started speaking")
-
-@socketio.on('audio_chunk')
-def handle_audio_chunk(data):
-    if request.sid not in conversation_context:
-        return
-
-    context = conversation_context[request.sid]
-
-    # Decode audio data
-    audio_data = base64.b64decode(data['audio'])
-    audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
-    audio_tensor = torch.tensor(audio_numpy)
-
-    # Add to buffer
-    context['audio_buffer'].append(audio_tensor)
-
-    # Check for silence to detect end of speech
-    if context['is_speaking'] and is_silence(audio_tensor):
-        if context['silence_start'] is None:
-            context['silence_start'] = time.time()
-        elif time.time() - context['silence_start'] > 1.0:  # 1 second of silence
-            # Process the complete utterance
-            process_user_utterance(request.sid)
-    else:
-        context['silence_start'] = None
-
-@socketio.on('stop_speaking')
-def handle_stop_speaking():
-    if request.sid in conversation_context:
-        conversation_context[request.sid]['is_speaking'] = False
-        process_user_utterance(request.sid)
-        print(f"User {request.sid} stopped speaking")
-
-def is_silence(audio_tensor, threshold=0.02):
-    """Check if an audio chunk is silence based on amplitude threshold"""
-    return torch.mean(torch.abs(audio_tensor)) < threshold
+    if session_id in active_audio_streams:
+        active_audio_streams[session_id]['active'] = False
+        del active_audio_streams[session_id]
+
+@socketio.on('audio_stream')
+def handle_audio_stream(data):
+    """Handle incoming audio stream from client"""
+    session_id = request.sid
+
+    if session_id not in conversation_context:
+        return
+
+    context = conversation_context[session_id]
+    context['last_activity'] = time.time()
+
+    # Process different stream events
+    if data.get('event') == 'start':
+        # Client is starting to send audio
+        context['is_speaking'] = True
+        context['audio_buffer'].clear()
+        context['transcription_buffer'] = []
+        print(f"User {session_id} started streaming audio")
+
+        # If AI was speaking, interrupt it
+        if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
+            active_audio_streams[session_id]['active'] = False
+            emit('ai_stream_interrupt', {}, room=session_id)
+
+    elif data.get('event') == 'data':
+        # Audio data received
+        if not context['is_speaking']:
+            return
+
+        # Decode audio chunk
+        try:
+            audio_data = base64.b64decode(data.get('audio', ''))
+            if not audio_data:
+                return
+
+            audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
+
+            # Apply a simple noise gate
+            if np.mean(np.abs(audio_numpy)) < 0.01:  # Very quiet
+                return
+
+            audio_tensor = torch.tensor(audio_numpy)
+
+            # Add to audio buffer
+            context['audio_buffer'].append(audio_tensor)
+
+            # Real-time transcription (periodic)
+            if len(context['audio_buffer']) % 3 == 0:  # Process every 3 chunks
+                threading.Thread(
+                    target=process_realtime_transcription,
+                    args=(session_id,),
+                    daemon=True
+                ).start()
+        except Exception as e:
+            print(f"Error processing audio chunk: {e}")
+
+    elif data.get('event') == 'end':
+        # Client has finished sending audio
+        context['is_speaking'] = False
+
+        if len(context['audio_buffer']) > 0:
+            # Process the complete utterance
+            threading.Thread(
+                target=process_complete_utterance,
+                args=(session_id,),
+                daemon=True
+            ).start()
+
+        print(f"User {session_id} stopped streaming audio")
+
+def process_realtime_transcription(session_id):
+    """Process incoming audio for real-time transcription"""
+    if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
+        return
+
+    context = conversation_context[session_id]
+
+    if not context['audio_buffer'] or not context['is_speaking']:
+        return
+
+    try:
+        # Combine current buffer for transcription
+        buffer_copy = list(context['audio_buffer'])
+        if not buffer_copy:
+            return
+
+        full_audio = torch.cat(buffer_copy, dim=0)
+
+        # Save audio to temporary WAV file for transcription
+        temp_audio_path = f"temp_rt_{session_id}.wav"
+        torchaudio.save(
+            temp_audio_path,
+            full_audio.unsqueeze(0),
+            44100  # Assuming 44.1kHz from client
+        )
+
+        # Transcribe with Whisper if available
+        if whisper_model is not None:
+            segments, _ = whisper_model.transcribe(temp_audio_path, beam_size=5)
+            text = " ".join([segment.text for segment in segments])
+
+            if text.strip():
+                context['transcription_buffer'].append(text)
+                # Send partial transcription to client
+                emit('partial_transcription', {'text': text}, room=session_id)
+    except Exception as e:
+        print(f"Error in realtime transcription: {e}")
+    finally:
+        # Clean up
+        if os.path.exists(temp_audio_path):
+            os.remove(temp_audio_path)
 
-def process_user_utterance(session_id):
+def process_complete_utterance(session_id):
     """Process completed user utterance, generate response and stream audio back"""
+    if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
+        return
+
     context = conversation_context[session_id]
 
     if not context['audio_buffer']:
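The new handler above replaces the old start_speaking / audio_chunk / stop_speaking trio with a single 'audio_stream' event whose payload carries event: 'start', 'data' (base64-encoded float32 PCM), or 'end'; the server answers with events such as 'partial_transcription', 'ai_stream_start', 'ai_stream_data', 'ai_stream_end' and 'ai_stream_interrupt'. The real client lives in Backend/index.html (suppressed above); the following minimal Python client is only an illustrative sketch of that protocol, assuming the python-socketio client package and a server on localhost:5000:

    import base64
    import numpy as np
    import socketio  # pip install "python-socketio[client]"

    sio = socketio.Client()

    @sio.on('partial_transcription')
    def on_partial(data):
        print('partial transcription:', data['text'])

    @sio.on('ai_stream_data')
    def on_chunk(data):
        chunk = base64.b64decode(data['audio'])  # audio chunk; format is set by the server
        print(f"received {len(chunk)} bytes, is_last={data['is_last']}")

    sio.connect('http://localhost:5000')
    sio.emit('audio_stream', {'event': 'start'})
    samples = np.zeros(4410, dtype=np.float32)  # ~0.1 s of silence as a dummy payload
    sio.emit('audio_stream', {
        'event': 'data',
        'audio': base64.b64encode(samples.tobytes()).decode('ascii'),
    })
    sio.emit('audio_stream', {'event': 'end'})
    sio.sleep(2)       # give the server a moment to respond
    sio.disconnect()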
@@ -212,8 +290,6 @@ def process_user_utterance(session_id):
     # Combine audio chunks
     full_audio = torch.cat(list(context['audio_buffer']), dim=0)
     context['audio_buffer'].clear()
-    context['is_speaking'] = False
-    context['silence_start'] = None
 
     # Save audio to temporary WAV file for transcription
     temp_audio_path = f"temp_audio_{session_id}.wav"
@@ -255,23 +331,23 @@ def process_user_utterance(session_id):
 
     # Generate and stream audio response if CSM is available
     if csm_generator is not None:
-        # Set up streaming queue for this session
-        if session_id not in audio_stream_queues:
-            audio_stream_queues[session_id] = queue.Queue()
-        else:
-            # Clear any existing items in the queue
-            while not audio_stream_queues[session_id].empty():
-                audio_stream_queues[session_id].get()
-
-        # Start audio generation in a separate thread to not block the server
+        # Create stream state object
+        active_audio_streams[session_id] = {
+            'active': True,
+            'text': bot_response
+        }
+
+        # Send initial response to prepare client
+        emit('ai_stream_start', {
+            'text': bot_response
+        }, room=session_id)
+
+        # Start audio generation in a separate thread
         threading.Thread(
-            target=generate_and_stream_audio,
+            target=generate_and_stream_audio_realtime,
             args=(bot_response, context['segments'], session_id),
             daemon=True
         ).start()
-
-        # Initial response with text
-        emit('start_streaming_response', {'text': bot_response}, room=session_id)
     else:
         # Send text-only response if audio generation isn't available
         emit('text_response', {'text': bot_response}, room=session_id)
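One consequence of moving generation into a worker thread here: flask_socketio's emit() helper needs the request context of the handler that is currently running, so it can only be used before the thread is spawned (the ai_stream_start message above); the worker itself has to fall back to socketio.emit(..., room=session_id), which is what generate_and_stream_audio_realtime does in the later hunks. Under async_mode='eventlet', Flask-SocketIO's own task helper is a common alternative to a raw threading.Thread; sketch only, not part of this commit:

    # Equivalent wiring using the extension's background-task helper (assumed alternative).
    socketio.start_background_task(
        generate_and_stream_audio_realtime,
        bot_response, context['segments'], session_id,
    )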
@@ -378,8 +454,11 @@ def fallback_response(user_text):
     else:
         return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available."
 
-def generate_audio_response(text, conversation_segments):
-    """Generate audio response using CSM"""
+def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
+    """Generate audio response using CSM and stream it in real-time to client"""
+    if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
+        return
+
     try:
         # Use the last few conversation segments as context
         context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
@@ -394,40 +473,23 @@ def generate_audio_response(text, conversation_segments):
             topk=50
         )
 
-        return audio
-    except Exception as e:
-        print(f"Error generating audio: {e}")
-        # Return silence as fallback
-        return torch.zeros(csm_generator.sample_rate * 3)  # 3 seconds of silence
-
-def generate_and_stream_audio(text, conversation_segments, session_id):
-    """Generate audio response using CSM and stream it in chunks"""
-    try:
-        # Use the last few conversation segments as context
-        context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
-
-        # Generate full audio for bot response
-        audio = csm_generator.generate(
-            text=text,
-            speaker=1,  # Bot is speaker 1
-            context=context_segments,
-            max_audio_length_ms=10000,  # 10 seconds max
-            temperature=0.9,
-            topk=50
-        )
-
         # Store the full audio for conversation history
         bot_segment = Segment(
             text=text,
             speaker=1,  # Bot is speaker 1
             audio=audio
         )
-        if session_id in conversation_context:
+        if session_id in conversation_context and conversation_context[session_id]['active_session']:
             conversation_context[session_id]['segments'].append(bot_segment)
 
-        # Split audio into chunks for streaming
-        chunk_size = CHUNK_SIZE
+        # Stream audio in small chunks for more responsive playback
+        chunk_size = 4800  # 200ms at 24kHz
 
         for i in range(0, len(audio), chunk_size):
+            if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
+                print("Audio streaming interrupted or session ended")
+                break
+
             chunk = audio[i:i+chunk_size]
 
             # Convert audio chunk to base64 for streaming
@@ -436,61 +498,33 @@ def generate_and_stream_audio(text, conversation_segments, session_id):
             audio_bytes.seek(0)
             audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')
 
-            # Send the chunk to the client
-            if session_id in audio_stream_queues:
-                audio_stream_queues[session_id].put({
-                    'audio': audio_b64,
-                    'is_last': i + chunk_size >= len(audio)
-                })
-            else:
-                # Session was disconnected before we finished generating
-                break
-
-        # Signal the end of streaming if queue still exists
-        if session_id in audio_stream_queues:
-            # Add an empty chunk as a sentinel to signal end of streaming
-            audio_stream_queues[session_id].put(None)
+            # Send chunk to client
+            socketio.emit('ai_stream_data', {
+                'audio': audio_b64,
+                'is_last': i + chunk_size >= len(audio)
+            }, room=session_id)
+
+            # Simulate real-time speech by adding a small delay
+            # Remove this in production for faster response
+            time.sleep(0.15)  # Slight delay for more natural timing
+
+        # Signal end of stream
+        if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
+            socketio.emit('ai_stream_end', {}, room=session_id)
+            active_audio_streams[session_id]['active'] = False
 
     except Exception as e:
         print(f"Error generating or streaming audio: {e}")
         # Send error message to client
-        if session_id in conversation_context:
+        if session_id in conversation_context and conversation_context[session_id]['active_session']:
             socketio.emit('error', {
                 'message': f'Error generating audio: {str(e)}'
             }, room=session_id)
 
-        # Send a final message to unblock the client
-        if session_id in audio_stream_queues:
-            audio_stream_queues[session_id].put(None)
-
-@socketio.on('request_audio_chunk')
-def handle_request_audio_chunk():
-    """Send the next audio chunk in the queue to the client"""
-    session_id = request.sid
-
-    if session_id not in audio_stream_queues:
-        emit('error', {'message': 'No audio stream available'})
-        return
-
-    # Get the next chunk or wait for it to be available
-    try:
-        if not audio_stream_queues[session_id].empty():
-            chunk = audio_stream_queues[session_id].get(block=False)
-
-            # If chunk is None, we're done streaming
-            if chunk is None:
-                emit('end_streaming')
-                # Clean up the queue
-                if session_id in audio_stream_queues:
-                    del audio_stream_queues[session_id]
-            else:
-                emit('audio_chunk', chunk)
-        else:
-            # If the queue is empty but we're still generating, tell client to wait
-            emit('wait_for_chunk')
-    except Exception as e:
-        print(f"Error sending audio chunk: {e}")
-        emit('error', {'message': f'Error streaming audio: {str(e)}'})
+        # Signal stream end to unblock client
+        socketio.emit('ai_stream_end', {}, room=session_id)
+        if session_id in active_audio_streams:
+            active_audio_streams[session_id]['active'] = False
 
 if __name__ == '__main__':
     # Ensure the existing index.html file is in the correct location
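The pacing of the stream above follows directly from the constants in the last two hunks: each emitted chunk holds 4800 samples of 24 kHz audio (200 ms) and the loop then sleeps 150 ms, so delivery runs slightly ahead of real-time playback while still spacing the chunks out. A quick check of that arithmetic:

    SAMPLE_RATE = 24000      # CSM output rate in Hz
    CHUNK_SAMPLES = 4800     # samples per 'ai_stream_data' chunk
    SLEEP_PER_CHUNK = 0.15   # seconds slept after each emit

    chunk_seconds = CHUNK_SAMPLES / SAMPLE_RATE      # 0.2 s of audio per chunk
    headroom = chunk_seconds - SLEEP_PER_CHUNK       # ~0.05 s gained per chunk
    print(f"{chunk_seconds:.2f}s of audio every {SLEEP_PER_CHUNK:.2f}s of wall time "
          f"(~{headroom:.2f}s headroom per chunk)")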
@@ -500,10 +534,10 @@ if __name__ == '__main__':
     if os.path.exists('index.html') and not os.path.exists('templates/index.html'):
         os.rename('index.html', 'templates/index.html')
 
-    # Load models asynchronously before starting the server
+    # Load models before starting the server
     print("Starting model loading...")
     load_models()
 
-    # Start the server
+    # Start the server with eventlet for better WebSocket performance
     print("Starting Flask SocketIO server...")
     socketio.run(app, host='0.0.0.0', port=5000, debug=False)