import os
import base64
import json
import time
import math
import gc
import logging
import numpy as np
import torch
import torchaudio
from io import BytesIO
from typing import List, Dict, Any, Optional
from flask import Flask, request, send_from_directory, Response
from flask_cors import CORS
from flask_socketio import SocketIO, emit, disconnect
from generator import load_csm_1b, Segment
from collections import deque
from threading import Lock
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("sesame-server")

# Determine best compute device
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    try:
        # Test CUDA functionality
        torch.rand(10, device="cuda")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        device = "cuda"
        logger.info("CUDA is fully functional")
    except Exception as e:
        logger.warning(f"CUDA available but not working correctly: {e}")
        device = "cpu"
else:
    device = "cpu"
    logger.info("Using CPU")

# Constants and Configuration
SILENCE_THRESHOLD = 0.01
SILENCE_DURATION_SEC = 0.75
MAX_BUFFER_SIZE = 30    # Maximum chunks to buffer before processing
CHUNK_SIZE_MS = 500     # Size of audio chunks when streaming responses

# Define the base directory and static files directory
base_dir = os.path.dirname(os.path.abspath(__file__))
static_dir = os.path.join(base_dir, "static")
os.makedirs(static_dir, exist_ok=True)


# Define a simple energy-based speech detector
class SpeechDetector:
    def __init__(self):
        self.min_speech_energy = 0.01
        self.speech_window = 0.2  # seconds

    def detect_speech(self, audio_tensor, sample_rate):
        # Calculate frame size based on window size
        frame_size = int(sample_rate * self.speech_window)

        # If audio is shorter than frame size, use the entire audio
        if audio_tensor.shape[0] < frame_size:
            frames = [audio_tensor]
        else:
            # Split audio into frames
            frames = [audio_tensor[i:i + frame_size]
                      for i in range(0, len(audio_tensor), frame_size)]

        # Calculate energy per frame
        energies = [torch.mean(frame ** 2).item() for frame in frames]

        # Determine if there's speech based on energy threshold
        has_speech = any(e > self.min_speech_energy for e in energies)
        return has_speech


speech_detector = SpeechDetector()
logger.info("Initialized simple speech detector")


# Model Loading Functions
def load_speech_models():
    """Load speech generation and recognition models."""
    # Load the CSM speech generator
    generator = load_csm_1b(device=device)

    # Load Whisper model for speech recognition
    try:
        logger.info(f"Loading speech recognition model on {device}...")
        speech_recognizer = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-small",
            device=device
        )
        logger.info("Speech recognition model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading speech recognition model: {e}")
        speech_recognizer = None

    return generator, speech_recognizer


# Unpack both models
generator, speech_recognizer = load_speech_models()


# Initialize Llama 3.2 model for conversation responses
def load_llm_model():
    """Load Llama 3.2 model for generating text responses."""
    try:
        logger.info("Loading Llama 3.2 model for conversational responses...")
        model_id = "meta-llama/Llama-3.2-1B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Determine compute device for LLM
        llm_device = "cpu"  # Default to CPU for LLM
        # Use CUDA if available and there's enough VRAM
        if device == "cuda" and torch.cuda.is_available():
            try:
                free_mem = (torch.cuda.get_device_properties(0).total_memory
                            - torch.cuda.memory_allocated(0))
                # If we have at least 2GB free, use CUDA for LLM
                if free_mem > 2 * 1024 * 1024 * 1024:
                    llm_device = "cuda"
            except Exception:
                pass

        logger.info(f"Using {llm_device} for Llama 3.2 model")

        # Load the model with lower precision for efficiency
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if llm_device == "cuda" else torch.float32,
            device_map=llm_device
        )

        # Create a pipeline for easier inference
        llm = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1
        )

        logger.info("Llama 3.2 model loaded successfully")
        return llm
    except Exception as e:
        logger.error(f"Error loading Llama 3.2 model: {e}")
        return None


# Load the LLM model
llm = load_llm_model()

# Set up Flask and Socket.IO
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')

# Socket connection management
thread_lock = Lock()
active_clients = {}  # Map client_id to client context


# Audio Utility Functions
def decode_audio_data(audio_data: str) -> torch.Tensor:
    """Decode base64 audio data to a torch tensor with improved error handling."""
    try:
        # Skip empty audio data
        if not audio_data or len(audio_data) < 100:
            logger.warning("Empty or too short audio data received")
            return torch.zeros(generator.sample_rate // 2)  # 0.5 seconds of silence

        # Extract the actual base64 content
        if ',' in audio_data:
            audio_data = audio_data.split(',')[1]

        # Decode base64 audio data
        try:
            binary_data = base64.b64decode(audio_data)
            logger.debug(f"Decoded base64 data: {len(binary_data)} bytes")

            # Check if we have enough data for a valid WAV
            if len(binary_data) < 44:  # WAV header is 44 bytes
                logger.warning("Data too small to be a valid WAV file")
                return torch.zeros(generator.sample_rate // 2)
        except Exception as e:
            logger.error(f"Base64 decoding error: {e}")
            return torch.zeros(generator.sample_rate // 2)

        # Multiple approaches to handle audio data
        audio_tensor = None
        sample_rate = None

        # Approach 1: Direct loading with torchaudio
        try:
            with BytesIO(binary_data) as temp_file:
                temp_file.seek(0)
                audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")
                logger.debug(f"Loaded audio: shape={audio_tensor.shape}, rate={sample_rate}Hz")

                # Validate tensor
                if audio_tensor.numel() == 0 or torch.isnan(audio_tensor).any():
                    raise ValueError("Invalid audio tensor")
        except Exception as e:
            logger.warning(f"Direct loading failed: {e}")

            # Approach 2: Using wave module and numpy
            try:
                temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav")
                with open(temp_path, 'wb') as f:
                    f.write(binary_data)

                import wave
                with wave.open(temp_path, 'rb') as wf:
                    n_channels = wf.getnchannels()
                    sample_width = wf.getsampwidth()
                    sample_rate = wf.getframerate()
                    n_frames = wf.getnframes()
                    frames = wf.readframes(n_frames)

                    # Convert to numpy array
                    if sample_width == 2:  # 16-bit audio
                        data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
                    elif sample_width == 1:  # 8-bit audio
                        data = np.frombuffer(frames, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0
                    else:
                        raise ValueError(f"Unsupported sample width: {sample_width}")

                    # Convert to mono if needed
                    if n_channels > 1:
                        data = data.reshape(-1, n_channels)
                        data = data.mean(axis=1)

                    # Convert to torch tensor
                    audio_tensor = torch.from_numpy(data)
                    logger.info(f"Loaded audio using wave: shape={audio_tensor.shape}")
shape={audio_tensor.shape}") # Clean up temp file if os.path.exists(temp_path): os.remove(temp_path) except Exception as e2: logger.error(f"All audio loading methods failed: {e2}") return torch.zeros(generator.sample_rate // 2) # Format corrections if audio_tensor is None: return torch.zeros(generator.sample_rate // 2) # Ensure audio is mono if len(audio_tensor.shape) > 1 and audio_tensor.shape[0] > 1: audio_tensor = torch.mean(audio_tensor, dim=0) # Ensure 1D tensor audio_tensor = audio_tensor.squeeze() # Resample if needed if sample_rate != generator.sample_rate: try: logger.debug(f"Resampling from {sample_rate}Hz to {generator.sample_rate}Hz") resampler = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=generator.sample_rate ) audio_tensor = resampler(audio_tensor) except Exception as e: logger.warning(f"Resampling error: {e}") # Normalize audio to avoid issues if torch.abs(audio_tensor).max() > 0: audio_tensor = audio_tensor / torch.abs(audio_tensor).max() return audio_tensor except Exception as e: logger.error(f"Unhandled error in decode_audio_data: {e}") return torch.zeros(generator.sample_rate // 2) def encode_audio_data(audio_tensor: torch.Tensor) -> str: """Encode torch tensor audio to base64 string""" try: buf = BytesIO() torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav") buf.seek(0) audio_base64 = base64.b64encode(buf.read()).decode('utf-8') return f"data:audio/wav;base64,{audio_base64}" except Exception as e: logger.error(f"Error encoding audio: {e}") # Return a minimal silent audio file silence = torch.zeros(generator.sample_rate // 2).unsqueeze(0) buf = BytesIO() torchaudio.save(buf, silence, generator.sample_rate, format="wav") buf.seek(0) return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}" def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str: """Process speech with speech recognition""" if not speech_recognizer: # Fallback to basic detection if model failed to load return detect_speech_energy(audio_tensor) try: # Save audio to temp file for Whisper temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav") torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate) # Perform speech recognition result = speech_recognizer(temp_path) transcription = result["text"] # Clean up temp file if os.path.exists(temp_path): os.remove(temp_path) # Return empty string if no speech detected if not transcription or transcription.isspace(): return "I didn't detect any speech. Could you please try again?" return transcription except Exception as e: logger.error(f"Speech recognition error: {e}") return "Sorry, I couldn't understand what you said. Could you try again?" def detect_speech_energy(audio_tensor: torch.Tensor) -> str: """Basic speech detection based on audio energy levels""" # Calculate audio energy energy = torch.mean(torch.abs(audio_tensor)).item() logger.debug(f"Audio energy detected: {energy:.6f}") # Generate response based on energy level if energy > 0.1: # Louder speech return "I heard you speaking clearly. How can I help you today?" elif energy > 0.05: # Moderate speech return "I heard you say something. Could you please repeat that?" elif energy > 0.02: # Soft speech return "I detected some speech, but it was quite soft. Could you speak up a bit?" else: # Very soft or no speech return "I didn't detect any speech. Could you please try again?" 
def generate_response(text: str, conversation_history: List[Segment]) -> str:
    """Generate a contextual response based on the transcribed text using Llama 3.2."""
    # If the LLM is not available, use simple rule-based responses
    if llm is None:
        return generate_simple_response(text)

    try:
        # Create a conversational prompt based on history
        # Format: recent conversation turns + current user input
        history_str = ""

        # Take the last 8 segments (roughly 4 user/assistant exchanges),
        # skipping empty or whitespace-only entries
        recent_segments = [
            seg for seg in conversation_history[-8:]
            if seg.text and not seg.text.isspace()
        ]
        for segment in recent_segments:
            speaker_name = "User" if segment.speaker == 0 else "Assistant"
            history_str += f"{speaker_name}: {segment.text}\n"

        # Construct the prompt for Llama 3.2
        prompt = f"""<|system|>
You are Sesame, a helpful, friendly and concise voice assistant.
Keep your responses conversational, natural, and to the point.
Respond to the user's latest message in the context of the conversation.
<|end|>
{history_str}
User: {text}
Assistant:"""

        logger.debug(f"LLM Prompt: {prompt}")

        # Generate response with the LLM
        result = llm(
            prompt,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1
        )

        # Extract the generated text
        response = result[0]["generated_text"]

        # Extract just the Assistant's response (after the prompt)
        response = response.split("Assistant:")[-1].strip()

        # Clean up and ensure it's not too long for TTS
        response = response.split("\n")[0].strip()
        if len(response) > 200:
            response = response[:197] + "..."

        logger.info(f"LLM response: {response}")
        return response
    except Exception as e:
        logger.error(f"Error generating LLM response: {e}")
        # Fall back to simple responses
        return generate_simple_response(text)


def generate_simple_response(text: str) -> str:
    """Generate a simple rule-based response as fallback."""
    responses = {
        "hello": "Hello there! How can I help you today?",
        "hi": "Hi there! What can I do for you?",
        "how are you": "I'm doing well, thanks for asking! How about you?",
        "what is your name": "I'm Sesame, your voice assistant. How can I help you?",
        "who are you": "I'm Sesame, an AI voice assistant. I'm here to chat with you!",
        "bye": "Goodbye! It was nice chatting with you.",
        "thank you": "You're welcome! Is there anything else I can help with?",
        "weather": "I don't have real-time weather data, but I hope it's nice where you are!",
        "help": "I can chat with you using natural voice. Just speak normally and I'll respond.",
        "what can you do": "I can have a conversation with you, answer questions, and provide assistance with various topics.",
    }

    text_lower = text.lower()

    # Check for matching keywords
    for key, response in responses.items():
        if key in text_lower:
            return response

    # Default responses based on text length
    if not text:
        return "I didn't catch that. Could you please repeat?"
    elif len(text) < 10:
        return "Thanks for your message. Could you elaborate a bit more?"
    else:
        return "I heard you say something about that. Can you tell me more?"
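# -----------------------------------------------------------------------------
# HTTP routes and Socket.IO protocol
# The browser client connects over Socket.IO and uses these events:
#   client -> server: 'generate', 'add_to_context', 'clear_context',
#                     'stream_audio', 'stop_streaming'
#   server -> client: 'status', 'streaming_status', 'transcription',
#                     'processing_status', 'audio_response', 'context_updated',
#                     'error', plus the chunked 'audio_response_*' events used
#                     by stream_audio_to_client()
# -----------------------------------------------------------------------------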
# Flask Routes
@app.route('/')
def index():
    return send_from_directory(base_dir, 'index.html')


@app.route('/favicon.ico')
def favicon():
    if os.path.exists(os.path.join(static_dir, 'favicon.ico')):
        return send_from_directory(static_dir, 'favicon.ico')
    return Response(status=204)


@app.route('/voice-chat.js')
def voice_chat_js():
    return send_from_directory(base_dir, 'voice-chat.js')


@app.route('/static/<path:path>')
def serve_static(path):
    return send_from_directory(static_dir, path)


# Socket.IO Event Handlers
@socketio.on('connect')
def handle_connect():
    client_id = request.sid
    logger.info(f"Client connected: {client_id}")

    # Initialize client context
    active_clients[client_id] = {
        'context_segments': [],
        'streaming_buffer': [],
        'is_streaming': False,
        'is_silence': False,
        'last_active_time': time.time(),
        'energy_window': deque(maxlen=10)
    }

    emit('status', {'type': 'connected', 'message': 'Connected to server'})


@socketio.on('disconnect')
def handle_disconnect():
    client_id = request.sid
    if client_id in active_clients:
        del active_clients[client_id]
    logger.info(f"Client disconnected: {client_id}")


@socketio.on('generate')
def handle_generate(data):
    client_id = request.sid
    if client_id not in active_clients:
        emit('error', {'message': 'Client not registered'})
        return

    try:
        text = data.get('text', '')
        speaker_id = data.get('speaker', 0)
        logger.info(f"Generating audio for: '{text}' with speaker {speaker_id}")

        # Generate audio response
        audio_tensor = generator.generate(
            text=text,
            speaker=speaker_id,
            context=active_clients[client_id]['context_segments'],
            max_audio_length_ms=10_000,
        )

        # Add to conversation context
        active_clients[client_id]['context_segments'].append(
            Segment(text=text, speaker=speaker_id, audio=audio_tensor)
        )

        # Convert audio to base64 and send back to client
        audio_base64 = encode_audio_data(audio_tensor)
        emit('audio_response', {
            'type': 'audio_response',
            'audio': audio_base64,
            'text': text
        })
    except Exception as e:
        logger.error(f"Error generating audio: {e}")
        emit('error', {
            'type': 'error',
            'message': f"Error generating audio: {str(e)}"
        })


@socketio.on('add_to_context')
def handle_add_to_context(data):
    client_id = request.sid
    if client_id not in active_clients:
        emit('error', {'message': 'Client not registered'})
        return

    try:
        text = data.get('text', '')
        speaker_id = data.get('speaker', 0)
        audio_data = data.get('audio', '')

        # Convert received audio to tensor
        audio_tensor = decode_audio_data(audio_data)

        # Add to conversation context
        active_clients[client_id]['context_segments'].append(
            Segment(text=text, speaker=speaker_id, audio=audio_tensor)
        )

        emit('context_updated', {
            'type': 'context_updated',
            'message': 'Audio added to context'
        })
    except Exception as e:
        logger.error(f"Error adding to context: {e}")
        emit('error', {
            'type': 'error',
            'message': f"Error processing audio: {str(e)}"
        })


@socketio.on('clear_context')
def handle_clear_context():
    client_id = request.sid
    if client_id in active_clients:
        active_clients[client_id]['context_segments'] = []
        emit('context_updated', {
            'type': 'context_updated',
            'message': 'Context cleared'
        })


@socketio.on('stream_audio')
def handle_stream_audio(data):
    client_id = request.sid
    if client_id not in active_clients:
        emit('error', {'message': 'Client not registered'})
        return

    client = active_clients[client_id]

    try:
        speaker_id = data.get('speaker', 0)
        audio_data = data.get('audio', '')

        # Skip if no audio data (might be just a connection test)
        if not audio_data:
            logger.debug("Empty audio data received, ignoring")
            return

        # Convert received audio to tensor
        audio_chunk = decode_audio_data(audio_data)
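        # Silence-detection state machine: each chunk's mean absolute amplitude
        # goes into a rolling window (deque of 10); when the windowed average
        # stays below SILENCE_THRESHOLD for SILENCE_DURATION_SEC the buffered
        # chunks are treated as a finished utterance, and an overly long buffer
        # (MAX_BUFFER_SIZE chunks) is flushed early as an incomplete utterance.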
        # Start streaming mode if not already started
        if not client['is_streaming']:
            client['is_streaming'] = True
            client['streaming_buffer'] = []
            client['energy_window'].clear()
            client['is_silence'] = False
            client['last_active_time'] = time.time()
            logger.info(f"[{client_id[:8]}] Streaming started with speaker ID: {speaker_id}")
            emit('streaming_status', {
                'type': 'streaming_status',
                'status': 'started'
            })

        # Calculate audio energy for silence detection
        chunk_energy = torch.mean(torch.abs(audio_chunk)).item()
        client['energy_window'].append(chunk_energy)
        avg_energy = sum(client['energy_window']) / len(client['energy_window'])

        # Check if audio is silent
        current_silence = avg_energy < SILENCE_THRESHOLD

        # Track silence transition
        if not client['is_silence'] and current_silence:
            # Transition to silence
            client['is_silence'] = True
            client['last_active_time'] = time.time()
        elif client['is_silence'] and not current_silence:
            # User started talking again
            client['is_silence'] = False

        # Add chunk to buffer regardless of silence state
        client['streaming_buffer'].append(audio_chunk)

        # Check if silence has persisted long enough to consider "stopped talking"
        silence_elapsed = time.time() - client['last_active_time']
        if client['is_silence'] and silence_elapsed >= SILENCE_DURATION_SEC and len(client['streaming_buffer']) > 0:
            # User has stopped talking - process the collected audio
            logger.info(f"[{client_id[:8]}] Processing audio after {silence_elapsed:.2f}s of silence")
            process_complete_utterance(client_id, client, speaker_id)
        # If buffer gets too large without silence, process it anyway
        elif len(client['streaming_buffer']) >= MAX_BUFFER_SIZE:
            logger.info(f"[{client_id[:8]}] Processing long audio segment without silence")
            process_complete_utterance(client_id, client, speaker_id, is_incomplete=True)
            # Keep half of the buffer for context (sliding window approach)
            half_point = len(client['streaming_buffer']) // 2
            client['streaming_buffer'] = client['streaming_buffer'][half_point:]
    except Exception as e:
        import traceback
        traceback.print_exc()
        logger.error(f"Error processing streaming audio: {e}")
        emit('error', {
            'type': 'error',
            'message': f"Error processing streaming audio: {str(e)}"
        })


def process_complete_utterance(client_id, client, speaker_id, is_incomplete=False):
    """Process a complete utterance (after silence or buffer limit)."""
    try:
        # Combine audio chunks
        full_audio = torch.cat(client['streaming_buffer'], dim=0)

        # Transcribe the collected audio (Whisper, with energy-based fallback)
        generated_text = process_speech(full_audio, client_id)

        # Add suffix for incomplete utterances
        if is_incomplete:
            generated_text += " (processing continued speech...)"

        # Log the generated text
        logger.info(f"[{client_id[:8]}] Generated text: '{generated_text}'")

        # Handle the result
        if generated_text:
            # Add user message to context
            user_segment = Segment(text=generated_text, speaker=speaker_id, audio=full_audio)
            client['context_segments'].append(user_segment)

            # Send the text to client
            emit('transcription', {
                'type': 'transcription',
                'text': generated_text
            }, room=client_id)

            # Only generate a response if this is a complete utterance
            if not is_incomplete:
                # Generate a contextual response
                response_text = generate_response(generated_text, client['context_segments'])
                logger.info(f"[{client_id[:8]}] Generating response: '{response_text}'")

                # Let the client know we're processing
                emit('processing_status', {
                    'type': 'processing_status',
                    'status': 'generating_audio',
                    'message': 'Generating audio response...'
                }, room=client_id)
                # Generate audio for the response
                try:
                    # Use a different speaker than the user
                    ai_speaker_id = 1 if speaker_id == 0 else 0

                    # Generate the full response
                    audio_tensor = generator.generate(
                        text=response_text,
                        speaker=ai_speaker_id,
                        context=client['context_segments'],
                        max_audio_length_ms=10_000,
                    )

                    # Add response to context
                    ai_segment = Segment(
                        text=response_text,
                        speaker=ai_speaker_id,
                        audio=audio_tensor
                    )
                    client['context_segments'].append(ai_segment)

                    # Convert audio to base64 and send back to client
                    audio_base64 = encode_audio_data(audio_tensor)
                    emit('audio_response', {
                        'type': 'audio_response',
                        'text': response_text,
                        'audio': audio_base64
                    }, room=client_id)
                    logger.info(f"[{client_id[:8]}] Audio response sent")
                except Exception as e:
                    logger.error(f"Error generating audio response: {e}")
                    emit('error', {
                        'type': 'error',
                        'message': "Sorry, there was an error generating the audio response."
                    }, room=client_id)
        else:
            # If processing failed, send a notification
            emit('error', {
                'type': 'error',
                'message': "Sorry, I couldn't understand what you said. Could you try again?"
            }, room=client_id)

        # Only clear buffer for complete utterances
        if not is_incomplete:
            # Reset state
            client['streaming_buffer'] = []
            client['energy_window'].clear()
            client['is_silence'] = False
            client['last_active_time'] = time.time()
    except Exception as e:
        logger.error(f"Error processing utterance: {e}")
        emit('error', {
            'type': 'error',
            'message': f"Error processing audio: {str(e)}"
        }, room=client_id)


@socketio.on('stop_streaming')
def handle_stop_streaming(data):
    client_id = request.sid
    if client_id not in active_clients:
        return

    client = active_clients[client_id]
    client['is_streaming'] = False

    if client['streaming_buffer'] and len(client['streaming_buffer']) > 5:
        # Process any remaining audio in the buffer
        logger.info(f"[{client_id[:8]}] Processing final audio buffer on stop")
        process_complete_utterance(client_id, client, data.get("speaker", 0))

    client['streaming_buffer'] = []
    emit('streaming_status', {
        'type': 'streaming_status',
        'status': 'stopped'
    })


def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size_ms=CHUNK_SIZE_MS):
    """Stream audio to client in chunks to simulate real-time generation."""
    try:
        if client_id not in active_clients:
            logger.warning(f"Client {client_id} not found for streaming")
            return

        # Calculate chunk size in samples
        chunk_size = int(generator.sample_rate * chunk_size_ms / 1000)
        total_chunks = math.ceil(audio_tensor.size(0) / chunk_size)
        logger.info(f"Streaming audio in {total_chunks} chunks of {chunk_size_ms}ms each")

        # Send initial response with text but no audio yet
        socketio.emit('audio_response_start', {
            'type': 'audio_response_start',
            'text': text,
            'total_chunks': total_chunks
        }, room=client_id)

        # Stream each chunk
        for i in range(total_chunks):
            start_idx = i * chunk_size
            end_idx = min(start_idx + chunk_size, audio_tensor.size(0))

            # Extract chunk
            chunk = audio_tensor[start_idx:end_idx]

            # Encode chunk
            chunk_base64 = encode_audio_data(chunk)

            # Send chunk
            socketio.emit('audio_response_chunk', {
                'type': 'audio_response_chunk',
                'chunk_index': i,
                'total_chunks': total_chunks,
                'audio': chunk_base64,
                'is_last': i == total_chunks - 1
            }, room=client_id)

            # Brief pause between chunks to simulate streaming
            time.sleep(0.1)

        # Send completion message
        socketio.emit('audio_response_complete', {
            'type': 'audio_response_complete',
            'text': text
        }, room=client_id)
        logger.info(f"Audio streaming complete: {total_chunks} chunks sent")
    except Exception as e:
        logger.error(f"Error streaming audio to client: {e}")
streaming audio to client: {e}") import traceback traceback.print_exc() # Main server start if __name__ == "__main__": print(f"\n{'='*60}") print(f"🔊 Sesame AI Voice Chat Server") print(f"{'='*60}") print(f"📡 Server Information:") print(f" - Local URL: http://localhost:5000") print(f" - Network URL: http://:5000") print(f"{'='*60}") print(f"🌐 Device: {device.upper()}") print(f"🧠 Models: Sesame CSM (TTS only)") print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}") print(f"{'='*60}") print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n") socketio.run(app, host="0.0.0.0", port=5000, debug=False)