import os import base64 import json import torch import torchaudio import numpy as np import whisperx from io import BytesIO from typing import List, Dict, Any, Optional from flask import Flask, request, send_from_directory, Response from flask_cors import CORS from flask_socketio import SocketIO, emit, disconnect from generator import load_csm_1b, Segment import time import gc from collections import deque from threading import Lock # Select device if torch.cuda.is_available(): device = "cuda" else: device = "cpu" print(f"Using device: {device}") # Initialize the model generator = load_csm_1b(device=device) # Initialize WhisperX for ASR print("Loading WhisperX model...") # Use a smaller model for faster response times asr_model = whisperx.load_model("medium", device, compute_type="float16") print("WhisperX model loaded!") # Silence detection parameters SILENCE_THRESHOLD = 0.01 # Adjust based on your audio normalization SILENCE_DURATION_SEC = 1.0 # How long silence must persist # Define the base directory base_dir = os.path.dirname(os.path.abspath(__file__)) static_dir = os.path.join(base_dir, "static") os.makedirs(static_dir, exist_ok=True) # Setup Flask app = Flask(__name__) CORS(app) socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet') # Socket connection management thread = None thread_lock = Lock() active_clients = {} # Map client_id to client context # Helper function to convert audio data def decode_audio_data(audio_data: str) -> torch.Tensor: """Decode base64 audio data to a torch tensor""" try: # Extract the actual base64 content if ',' in audio_data: audio_data = audio_data.split(',')[1] # Decode base64 audio data binary_data = base64.b64decode(audio_data) # Load audio from binary data with BytesIO(binary_data) as temp_file: audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav") # Resample if needed if sample_rate != generator.sample_rate: audio_tensor = torchaudio.functional.resample( audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate ) else: audio_tensor = audio_tensor.squeeze(0) return audio_tensor except Exception as e: print(f"Error decoding audio: {str(e)}") # Return a small silent audio segment as fallback return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence def encode_audio_data(audio_tensor: torch.Tensor) -> str: """Encode torch tensor audio to base64 string""" buf = BytesIO() torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav") buf.seek(0) audio_base64 = base64.b64encode(buf.read()).decode('utf-8') return f"data:audio/wav;base64,{audio_base64}" def transcribe_audio(audio_tensor: torch.Tensor) -> str: """Transcribe audio using WhisperX""" try: # Save the tensor to a temporary file temp_path = os.path.join(base_dir, "temp_audio.wav") torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate) # Load and transcribe the audio audio = whisperx.load_audio(temp_path) result = asr_model.transcribe(audio, batch_size=16) # Clean up if os.path.exists(temp_path): os.remove(temp_path) # Get the transcription text if result["segments"] and len(result["segments"]) > 0: # Combine all segments transcription = " ".join([segment["text"] for segment in result["segments"]]) return transcription.strip() else: return "" except Exception as e: print(f"Error in transcription: {str(e)}") if os.path.exists("temp_audio.wav"): os.remove("temp_audio.wav") return "" def generate_response(text: str, conversation_history: List[Segment]) -> str: """Generate a contextual response based on the transcribed text""" # Simple response logic - can be replaced with a more sophisticated LLM in the future responses = { "hello": "Hello there! How are you doing today?", "how are you": "I'm doing well, thanks for asking! How about you?", "what is your name": "I'm Sesame, your voice assistant. How can I help you?", "bye": "Goodbye! It was nice chatting with you.", "thank you": "You're welcome! Is there anything else I can help with?", "weather": "I don't have real-time weather data, but I hope it's nice where you are!", "help": "I can chat with you using natural voice. Just speak normally and I'll respond.", } text_lower = text.lower() # Check for matching keywords for key, response in responses.items(): if key in text_lower: return response # Default responses based on text length if not text: return "I didn't catch that. Could you please repeat?" elif len(text) < 10: return "Thanks for your message. Could you elaborate a bit more?" else: return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?" # Flask routes for serving static content @app.route('/') def index(): return send_from_directory(base_dir, 'index.html') @app.route('/favicon.ico') def favicon(): if os.path.exists(os.path.join(static_dir, 'favicon.ico')): return send_from_directory(static_dir, 'favicon.ico') return Response(status=204) @app.route('/voice-chat.js') def voice_chat_js(): return send_from_directory(base_dir, 'voice-chat.js') @app.route('/static/') def serve_static(path): return send_from_directory(static_dir, path) # Socket.IO event handlers @socketio.on('connect') def handle_connect(): client_id = request.sid print(f"Client connected: {client_id}") # Initialize client context active_clients[client_id] = { 'context_segments': [], 'streaming_buffer': [], 'is_streaming': False, 'is_silence': False, 'last_active_time': time.time(), 'energy_window': deque(maxlen=10) } emit('status', {'type': 'connected', 'message': 'Connected to server'}) @socketio.on('disconnect') def handle_disconnect(): client_id = request.sid if client_id in active_clients: del active_clients[client_id] print(f"Client disconnected: {client_id}") @socketio.on('generate') def handle_generate(data): client_id = request.sid if client_id not in active_clients: emit('error', {'message': 'Client not registered'}) return try: text = data.get('text', '') speaker_id = data.get('speaker', 0) print(f"Generating audio for: '{text}' with speaker {speaker_id}") # Generate audio response audio_tensor = generator.generate( text=text, speaker=speaker_id, context=active_clients[client_id]['context_segments'], max_audio_length_ms=10_000, ) # Add to conversation context active_clients[client_id]['context_segments'].append( Segment(text=text, speaker=speaker_id, audio=audio_tensor) ) # Convert audio to base64 and send back to client audio_base64 = encode_audio_data(audio_tensor) emit('audio_response', { 'type': 'audio_response', 'audio': audio_base64 }) except Exception as e: print(f"Error generating audio: {str(e)}") emit('error', { 'type': 'error', 'message': f"Error generating audio: {str(e)}" }) @socketio.on('add_to_context') def handle_add_to_context(data): client_id = request.sid if client_id not in active_clients: emit('error', {'message': 'Client not registered'}) return try: text = data.get('text', '') speaker_id = data.get('speaker', 0) audio_data = data.get('audio', '') # Convert received audio to tensor audio_tensor = decode_audio_data(audio_data) # Add to conversation context active_clients[client_id]['context_segments'].append( Segment(text=text, speaker=speaker_id, audio=audio_tensor) ) emit('context_updated', { 'type': 'context_updated', 'message': 'Audio added to context' }) except Exception as e: print(f"Error adding to context: {str(e)}") emit('error', { 'type': 'error', 'message': f"Error processing audio: {str(e)}" }) @socketio.on('clear_context') def handle_clear_context(): client_id = request.sid if client_id in active_clients: active_clients[client_id]['context_segments'] = [] emit('context_updated', { 'type': 'context_updated', 'message': 'Context cleared' }) @socketio.on('stream_audio') def handle_stream_audio(data): client_id = request.sid if client_id not in active_clients: emit('error', {'message': 'Client not registered'}) return client = active_clients[client_id] try: speaker_id = data.get('speaker', 0) audio_data = data.get('audio', '') # Convert received audio to tensor audio_chunk = decode_audio_data(audio_data) # Start streaming mode if not already started if not client['is_streaming']: client['is_streaming'] = True client['streaming_buffer'] = [] client['energy_window'].clear() client['is_silence'] = False client['last_active_time'] = time.time() print(f"[{client_id}] Streaming started with speaker ID: {speaker_id}") emit('streaming_status', { 'type': 'streaming_status', 'status': 'started' }) # Calculate audio energy for silence detection chunk_energy = torch.mean(torch.abs(audio_chunk)).item() client['energy_window'].append(chunk_energy) avg_energy = sum(client['energy_window']) / len(client['energy_window']) # Check if audio is silent current_silence = avg_energy < SILENCE_THRESHOLD # Track silence transition if not client['is_silence'] and current_silence: # Transition to silence client['is_silence'] = True client['last_active_time'] = time.time() elif client['is_silence'] and not current_silence: # User started talking again client['is_silence'] = False # Add chunk to buffer regardless of silence state client['streaming_buffer'].append(audio_chunk) # Check if silence has persisted long enough to consider "stopped talking" silence_elapsed = time.time() - client['last_active_time'] if client['is_silence'] and silence_elapsed >= SILENCE_DURATION_SEC and len(client['streaming_buffer']) > 0: # User has stopped talking - process the collected audio print(f"[{client_id}] Processing audio after {silence_elapsed:.2f}s of silence") full_audio = torch.cat(client['streaming_buffer'], dim=0) # Process with WhisperX speech-to-text print(f"[{client_id}] Starting transcription with WhisperX...") transcribed_text = transcribe_audio(full_audio) # Log the transcription print(f"[{client_id}] Transcribed text: '{transcribed_text}'") # Add to conversation context if transcribed_text: user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio) client['context_segments'].append(user_segment) # Generate a contextual response response_text = generate_response(transcribed_text, client['context_segments']) # Send the transcribed text to client emit('transcription', { 'type': 'transcription', 'text': transcribed_text }) # Generate audio for the response audio_tensor = generator.generate( text=response_text, speaker=1 if speaker_id == 0 else 0, # Use opposite speaker context=client['context_segments'], max_audio_length_ms=10_000, ) # Add response to context ai_segment = Segment( text=response_text, speaker=1 if speaker_id == 0 else 0, audio=audio_tensor ) client['context_segments'].append(ai_segment) # Convert audio to base64 and send back to client audio_base64 = encode_audio_data(audio_tensor) emit('audio_response', { 'type': 'audio_response', 'text': response_text, 'audio': audio_base64 }) else: # If transcription failed, send a generic response emit('error', { 'type': 'error', 'message': "Sorry, I couldn't understand what you said. Could you try again?" }) # Clear buffer and reset silence detection client['streaming_buffer'] = [] client['energy_window'].clear() client['is_silence'] = False client['last_active_time'] = time.time() # If buffer gets too large without silence, process it anyway elif len(client['streaming_buffer']) >= 30: # ~6 seconds of audio at 5 chunks/sec full_audio = torch.cat(client['streaming_buffer'], dim=0) # Process with WhisperX speech-to-text transcribed_text = transcribe_audio(full_audio) if transcribed_text: client['context_segments'].append( Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio) ) # Send the transcribed text to client emit('transcription', { 'type': 'transcription', 'text': transcribed_text + " (processing continued speech...)" }) client['streaming_buffer'] = [] except Exception as e: import traceback traceback.print_exc() print(f"Error processing streaming audio: {str(e)}") emit('error', { 'type': 'error', 'message': f"Error processing streaming audio: {str(e)}" }) @socketio.on('stop_streaming') def handle_stop_streaming(data): client_id = request.sid if client_id not in active_clients: return client = active_clients[client_id] client['is_streaming'] = False if client['streaming_buffer'] and len(client['streaming_buffer']) > 5: # Process any remaining audio in the buffer full_audio = torch.cat(client['streaming_buffer'], dim=0) # Process with WhisperX speech-to-text transcribed_text = transcribe_audio(full_audio) if transcribed_text: client['context_segments'].append( Segment(text=transcribed_text, speaker=data.get("speaker", 0), audio=full_audio) ) # Send the transcribed text to client emit('transcription', { 'type': 'transcription', 'text': transcribed_text }) client['streaming_buffer'] = [] emit('streaming_status', { 'type': 'streaming_status', 'status': 'stopped' }) if __name__ == "__main__": print(f"\n{'='*60}") print(f"🔊 Sesame AI Voice Chat Server (Flask Implementation)") print(f"{'='*60}") print(f"📡 Server Information:") print(f" - Local URL: http://localhost:5000") print(f" - Network URL: http://:5000") print(f" - WebSocket: ws://:5000/socket.io") print(f"{'='*60}") print(f"💡 To make this server public:") print(f" 1. Ensure port 5000 is open in your firewall") print(f" 2. Set up port forwarding on your router to port 5000") print(f" 3. Or use a service like ngrok with: ngrok http 5000") print(f"{'='*60}") print(f"🌐 Device: {device.upper()}") print(f"🧠 Models loaded: Sesame CSM + WhisperX ({asr_model.device})") print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}") print(f"{'='*60}") print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n") socketio.run(app, host="0.0.0.0", port=5000, debug=False)