From 6152e300c000793d3f5682dda2ac1431fc03a12e Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 00:24:26 -0400 Subject: [PATCH] Demo Update 6 --- Backend/server.py | 659 +++++++++++++++++++++++----------------------- 1 file changed, 335 insertions(+), 324 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index d0dee80..8ba56b4 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -1,9 +1,13 @@ import os import base64 import json +import time +import math +import gc +import logging +import numpy as np import torch import torchaudio -import numpy as np import whisperx from io import BytesIO from typing import List, Dict, Any, Optional @@ -11,290 +15,314 @@ from flask import Flask, request, send_from_directory, Response from flask_cors import CORS from flask_socketio import SocketIO, emit, disconnect from generator import load_csm_1b, Segment -import time -import gc from collections import deque from threading import Lock -# Add this at the top of your file, replacing your current CUDA setup +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger("sesame-server") -# CUDA setup with robust error handling -try: - # Handle CUDA issues - os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Limit to first GPU only +# CUDA Environment Setup +def setup_cuda_environment(): + """Set up CUDA environment with proper error handling""" + # Search for CUDA libraries in common locations + cuda_lib_dirs = [ + "/usr/local/cuda/lib64", + "/usr/lib/x86_64-linux-gnu", + "/usr/local/cuda/extras/CUPTI/lib64" + ] - # Try enabling TF32 precision - try: - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - except: - pass # Ignore if not supported + # Add directories to LD_LIBRARY_PATH if they exist + current_ld_path = os.environ.get('LD_LIBRARY_PATH', '') + for cuda_dir in cuda_lib_dirs: + if os.path.exists(cuda_dir) and cuda_dir not in current_ld_path: + if current_ld_path: + os.environ['LD_LIBRARY_PATH'] = f"{current_ld_path}:{cuda_dir}" + else: + os.environ['LD_LIBRARY_PATH'] = cuda_dir + current_ld_path = os.environ['LD_LIBRARY_PATH'] - # Check if CUDA is available - if torch.cuda.is_available(): - try: - # Test CUDA functionality - x = torch.rand(10, device="cuda") - y = x + x - del x, y - device = "cuda" - compute_type = "float16" - print("CUDA is fully functional") - except Exception as cuda_error: - print(f"CUDA is available but not working correctly: {str(cuda_error)}") - device = "cpu" - compute_type = "int8" - else: - device = "cpu" - compute_type = "int8" -except Exception as e: - print(f"Error setting up CUDA: {str(e)}") + logger.info(f"LD_LIBRARY_PATH set to: {os.environ.get('LD_LIBRARY_PATH', 'not set')}") + + # Determine best compute device device = "cpu" compute_type = "int8" - -print(f"Using device: {device} with compute type: {compute_type}") - -# Initialize the Sesame CSM model with robust error handling -try: - print(f"Loading Sesame CSM model on {device}...") - generator = load_csm_1b(device=device) - print("Sesame CSM model loaded successfully") -except Exception as model_error: - print(f"Error loading Sesame CSM on {device}: {str(model_error)}") - if device == "cuda": - # Try on CPU as fallback - try: - print("Trying to load Sesame CSM on CPU instead...") - device = "cpu" # Update global device setting - generator = load_csm_1b(device="cpu") - print("Sesame CSM model loaded on CPU successfully") - except Exception as cpu_error: - 
print(f"Fatal error - could not load Sesame CSM model: {str(cpu_error)}") - raise RuntimeError("Failed to load speech synthesis model") - else: - # Already tried CPU and it failed - raise RuntimeError("Failed to load speech synthesis model on any device") - -# Replace the WhisperX model loading section - -# Initialize WhisperX for ASR with robust error handling -print("Loading WhisperX model...") -asr_model = None # Initialize to None first to avoid scope issues - -try: - # Always start with the tiny model on CPU for stability - asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8") - print("WhisperX 'tiny' model loaded on CPU successfully") - # If CPU works, try CUDA if available - if device == "cuda": + try: + # Set CUDA preferences + os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Limit to first GPU only + + # Try enabling TF32 precision if available try: - print("Trying to load WhisperX on CUDA...") - cuda_model = whisperx.load_model("tiny", "cuda", compute_type="float16") - # Test the model to ensure it works - test_audio = torch.zeros(16000) # 1 second of silence at 16kHz - _ = cuda_model.transcribe(test_audio.numpy(), batch_size=1) - # If we get here, CUDA works - asr_model = cuda_model - print("WhisperX model moved to CUDA successfully") - - # Try to upgrade to small model on CUDA + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + except Exception as e: + logger.warning(f"Could not set advanced CUDA options: {e}") + + # Test if CUDA is functional + if torch.cuda.is_available(): try: - small_model = whisperx.load_model("small", "cuda", compute_type="float16") - # Test it - _ = small_model.transcribe(test_audio.numpy(), batch_size=1) - asr_model = small_model - print("WhisperX 'small' model loaded on CUDA successfully") - except Exception as upgrade_error: - print(f"Staying with 'tiny' model on CUDA: {str(upgrade_error)}") - except Exception as cuda_error: - print(f"CUDA loading failed, staying with CPU model: {str(cuda_error)}") -except Exception as e: - print(f"Error loading WhisperX model: {str(e)}") - # Create a minimal dummy model as last resort - class DummyModel: - def __init__(self): - self.device = "cpu" - def transcribe(self, *args, **kwargs): - return {"segments": [{"text": "Speech recognition currently unavailable."}]} + # Test basic CUDA operations + x = torch.rand(10, device="cuda") + y = x + x + del x, y + torch.cuda.empty_cache() + device = "cuda" + compute_type = "float16" + logger.info("CUDA is fully functional") + except Exception as e: + logger.warning(f"CUDA available but not working correctly: {e}") + device = "cpu" + else: + logger.info("CUDA is not available, using CPU") + except Exception as e: + logger.error(f"Error setting up computing environment: {e}") - asr_model = DummyModel() - print("WARNING: Using dummy transcription model - ASR functionality limited") + return device, compute_type -# Silence detection parameters -SILENCE_THRESHOLD = 0.01 # Adjust based on your audio normalization -SILENCE_DURATION_SEC = 1.0 # How long silence must persist +# Set up the compute environment +device, compute_type = setup_cuda_environment() -# Define the base directory +# Constants and Configuration +SILENCE_THRESHOLD = 0.01 +SILENCE_DURATION_SEC = 0.75 +MAX_BUFFER_SIZE = 30 # Maximum chunks to buffer before processing +CHUNK_SIZE_MS = 500 # Size of audio chunks when streaming responses + +# Define the base directory and static files directory 
base_dir = os.path.dirname(os.path.abspath(__file__)) static_dir = os.path.join(base_dir, "static") os.makedirs(static_dir, exist_ok=True) -# Setup Flask +# Model Loading Functions +def load_speech_models(): + """Load all required speech models with fallbacks""" + # Load speech generation model (Sesame CSM) + try: + logger.info(f"Loading Sesame CSM model on {device}...") + generator = load_csm_1b(device=device) + logger.info("Sesame CSM model loaded successfully") + except Exception as e: + logger.error(f"Error loading Sesame CSM on {device}: {e}") + if device == "cuda": + try: + logger.info("Trying to load Sesame CSM on CPU instead...") + generator = load_csm_1b(device="cpu") + logger.info("Sesame CSM model loaded on CPU successfully") + except Exception as cpu_error: + logger.critical(f"Failed to load speech synthesis model: {cpu_error}") + raise RuntimeError("Failed to load speech synthesis model") + else: + raise RuntimeError("Failed to load speech synthesis model on any device") + + # Load ASR model (WhisperX) + try: + logger.info("Loading WhisperX model...") + # Start with the tiny model on CPU for reliable initialization + asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8") + logger.info("WhisperX 'tiny' model loaded on CPU successfully") + + # Try upgrading to GPU if available + if device == "cuda": + try: + logger.info("Trying to load WhisperX on CUDA...") + # Test with a tiny model first + test_audio = torch.zeros(16000) # 1 second of silence + + cuda_model = whisperx.load_model("tiny", "cuda", compute_type="float16") + # Test the model with real inference + _ = cuda_model.transcribe(test_audio.numpy(), batch_size=1) + asr_model = cuda_model + logger.info("WhisperX model running on CUDA successfully") + + # Try to upgrade to small model + try: + small_model = whisperx.load_model("small", "cuda", compute_type="float16") + _ = small_model.transcribe(test_audio.numpy(), batch_size=1) + asr_model = small_model + logger.info("WhisperX 'small' model loaded on CUDA successfully") + except Exception as e: + logger.warning(f"Staying with 'tiny' model on CUDA: {e}") + except Exception as e: + logger.warning(f"CUDA loading failed, staying with CPU model: {e}") + except Exception as e: + logger.error(f"Error loading WhisperX model: {e}") + # Create a minimal dummy model as last resort + class DummyModel: + def __init__(self): + self.device = "cpu" + def transcribe(self, *args, **kwargs): + return {"segments": [{"text": "Speech recognition currently unavailable."}]} + + asr_model = DummyModel() + logger.warning("Using dummy transcription model - ASR functionality limited") + + return generator, asr_model + +# Load speech models +generator, asr_model = load_speech_models() + +# Set up Flask and Socket.IO app = Flask(__name__) CORS(app) socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet') # Socket connection management -thread = None thread_lock = Lock() active_clients = {} # Map client_id to client context -# Helper function to convert audio data +# Audio Utility Functions def decode_audio_data(audio_data: str) -> torch.Tensor: """Decode base64 audio data to a torch tensor with improved error handling""" try: # Skip empty audio data if not audio_data or len(audio_data) < 100: - print("Empty or too short audio data received") + logger.warning("Empty or too short audio data received") return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence # Extract the actual base64 content if ',' in audio_data: - # Handle data URL format 
(data:audio/wav;base64,...) audio_data = audio_data.split(',')[1] # Decode base64 audio data try: binary_data = base64.b64decode(audio_data) - print(f"Decoded base64 data: {len(binary_data)} bytes") + logger.debug(f"Decoded base64 data: {len(binary_data)} bytes") # Check if we have enough data for a valid WAV if len(binary_data) < 44: # WAV header is 44 bytes - print("Data too small to be a valid WAV file") + logger.warning("Data too small to be a valid WAV file") return torch.zeros(generator.sample_rate // 2) except Exception as e: - print(f"Base64 decoding error: {str(e)}") + logger.error(f"Base64 decoding error: {e}") return torch.zeros(generator.sample_rate // 2) - # Save for debugging - debug_path = os.path.join(base_dir, "debug_incoming.wav") - with open(debug_path, 'wb') as f: - f.write(binary_data) - print(f"Saved debug file: {debug_path}") + # Multiple approaches to handle audio data + audio_tensor = None + sample_rate = None - # Approach 1: Load directly with torchaudio + # Approach 1: Direct loading with torchaudio try: with BytesIO(binary_data) as temp_file: - temp_file.seek(0) # Ensure we're at the start of the buffer + temp_file.seek(0) audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav") - print(f"Direct loading success: shape={audio_tensor.shape}, rate={sample_rate}Hz") + logger.debug(f"Loaded audio: shape={audio_tensor.shape}, rate={sample_rate}Hz") - # Check if audio is valid + # Validate tensor if audio_tensor.numel() == 0 or torch.isnan(audio_tensor).any(): - raise ValueError("Empty or invalid audio tensor detected") + raise ValueError("Invalid audio tensor") except Exception as e: - print(f"Direct loading failed: {str(e)}") + logger.warning(f"Direct loading failed: {e}") - # Approach 2: Try to fix/normalize the WAV data + # Approach 2: Using wave module and numpy try: - # Sometimes WAV headers can be malformed, attempt to fix - temp_path = os.path.join(base_dir, "temp_fixing.wav") + temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav") with open(temp_path, 'wb') as f: f.write(binary_data) - # Use a simpler numpy approach as backup - import numpy as np import wave + with wave.open(temp_path, 'rb') as wf: + n_channels = wf.getnchannels() + sample_width = wf.getsampwidth() + sample_rate = wf.getframerate() + n_frames = wf.getnframes() + frames = wf.readframes(n_frames) + + # Convert to numpy array + if sample_width == 2: # 16-bit audio + data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0 + elif sample_width == 1: # 8-bit audio + data = np.frombuffer(frames, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0 + else: + raise ValueError(f"Unsupported sample width: {sample_width}") + + # Convert to mono if needed + if n_channels > 1: + data = data.reshape(-1, n_channels) + data = data.mean(axis=1) + + # Convert to torch tensor + audio_tensor = torch.from_numpy(data) + logger.info(f"Loaded audio using wave: shape={audio_tensor.shape}") - try: - with wave.open(temp_path, 'rb') as wf: - n_channels = wf.getnchannels() - sample_width = wf.getsampwidth() - sample_rate = wf.getframerate() - n_frames = wf.getnframes() - - # Read the frames - frames = wf.readframes(n_frames) - print(f"Wave reading: channels={n_channels}, rate={sample_rate}Hz, frames={n_frames}") - - # Convert to numpy and then to torch - if sample_width == 2: # 16-bit audio - data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0 - elif sample_width == 1: # 8-bit audio - data = np.frombuffer(frames, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0 - else: 
- raise ValueError(f"Unsupported sample width: {sample_width}") - - # Convert to mono if needed - if n_channels > 1: - data = data.reshape(-1, n_channels) - data = data.mean(axis=1) - - # Convert to torch tensor - audio_tensor = torch.from_numpy(data) - print(f"Successfully converted with numpy: shape={audio_tensor.shape}") - except Exception as wave_error: - print(f"Wave processing failed: {str(wave_error)}") - # Try with torchaudio as last resort - audio_tensor, sample_rate = torchaudio.load(temp_path, format="wav") - - # Clean up + # Clean up temp file if os.path.exists(temp_path): os.remove(temp_path) + except Exception as e2: - print(f"All WAV loading methods failed: {str(e2)}") - print("Returning silence as fallback") + logger.error(f"All audio loading methods failed: {e2}") return torch.zeros(generator.sample_rate // 2) - # Ensure audio is the right shape (mono) + # Format corrections + if audio_tensor is None: + return torch.zeros(generator.sample_rate // 2) + + # Ensure audio is mono if len(audio_tensor.shape) > 1 and audio_tensor.shape[0] > 1: audio_tensor = torch.mean(audio_tensor, dim=0) - # Ensure we have a 1D tensor + # Ensure 1D tensor audio_tensor = audio_tensor.squeeze() # Resample if needed if sample_rate != generator.sample_rate: try: - print(f"Resampling from {sample_rate}Hz to {generator.sample_rate}Hz") + logger.debug(f"Resampling from {sample_rate}Hz to {generator.sample_rate}Hz") resampler = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=generator.sample_rate ) audio_tensor = resampler(audio_tensor) except Exception as e: - print(f"Resampling error: {str(e)}") - # If resampling fails, just return the original audio - # The model can often handle different sample rates + logger.warning(f"Resampling error: {e}") # Normalize audio to avoid issues if torch.abs(audio_tensor).max() > 0: audio_tensor = audio_tensor / torch.abs(audio_tensor).max() - print(f"Final audio tensor: shape={audio_tensor.shape}, min={audio_tensor.min().item():.4f}, max={audio_tensor.max().item():.4f}") return audio_tensor except Exception as e: - print(f"Unhandled error in decode_audio_data: {str(e)}") - # Return a small silent audio segment as fallback - return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence - + logger.error(f"Unhandled error in decode_audio_data: {e}") + return torch.zeros(generator.sample_rate // 2) def encode_audio_data(audio_tensor: torch.Tensor) -> str: """Encode torch tensor audio to base64 string""" - buf = BytesIO() - torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav") - buf.seek(0) - audio_base64 = base64.b64encode(buf.read()).decode('utf-8') - return f"data:audio/wav;base64,{audio_base64}" - + try: + buf = BytesIO() + torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav") + buf.seek(0) + audio_base64 = base64.b64encode(buf.read()).decode('utf-8') + return f"data:audio/wav;base64,{audio_base64}" + except Exception as e: + logger.error(f"Error encoding audio: {e}") + # Return a minimal silent audio file + silence = torch.zeros(generator.sample_rate // 2).unsqueeze(0) + buf = BytesIO() + torchaudio.save(buf, silence, generator.sample_rate, format="wav") + buf.seek(0) + return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}" def transcribe_audio(audio_tensor: torch.Tensor) -> str: """Transcribe audio using WhisperX with robust error handling""" - global asr_model # Declare global at the beginning of the function + global asr_model try: # Save 
the tensor to a temporary file - temp_path = os.path.join(base_dir, "temp_audio.wav") + temp_path = os.path.join(base_dir, f"temp_audio_{time.time()}.wav") torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate) - print(f"Transcribing audio file: {temp_path} (size: {os.path.getsize(temp_path)} bytes)") + logger.info(f"Transcribing audio file: {os.path.getsize(temp_path)} bytes") - # Load the audio file using whisperx's function + # Load the audio for WhisperX try: audio = whisperx.load_audio(temp_path) - except Exception as audio_load_error: - print(f"WhisperX load_audio failed: {str(audio_load_error)}") + except Exception as e: + logger.warning(f"WhisperX load_audio failed: {e}") # Fall back to manual loading import soundfile as sf audio, sr = sf.read(temp_path) @@ -302,59 +330,55 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str: from scipy import signal audio = signal.resample(audio, int(len(audio) * 16000 / sr)) - # Transcribe with error handling for CUDA issues + # Transcribe with error handling try: - # Try with original device - result = asr_model.transcribe(audio, batch_size=8) - except RuntimeError as cuda_error: - if "CUDA" in str(cuda_error) or "libcudnn" in str(cuda_error): - print(f"CUDA error in transcription, falling back to CPU: {str(cuda_error)}") - - # Try to load a CPU model as fallback + result = asr_model.transcribe(audio, batch_size=4) + except RuntimeError as e: + if "CUDA" in str(e) or "libcudnn" in str(e): + logger.warning(f"CUDA error in transcription, falling back to CPU: {e}") try: - # Move model to CPU and try again - asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8") - result = asr_model.transcribe(audio, batch_size=1) - except Exception as e: - print(f"CPU fallback also failed: {str(e)}") + # Try CPU model + cpu_model = whisperx.load_model("tiny", "cpu", compute_type="int8") + result = cpu_model.transcribe(audio, batch_size=1) + # Update the global model if the original one is broken + asr_model = cpu_model + except Exception as cpu_e: + logger.error(f"CPU fallback failed: {cpu_e}") return "I'm having trouble processing audio right now." else: - # Re-raise if it's not a CUDA error raise + finally: + # Clean up + if os.path.exists(temp_path): + os.remove(temp_path) - # Clean up + # Extract text from segments + if result["segments"] and len(result["segments"]) > 0: + transcription = " ".join([segment["text"] for segment in result["segments"]]) + logger.info(f"Transcription: '{transcription.strip()}'") + return transcription.strip() + + return "" + except Exception as e: + logger.error(f"Error in transcription: {e}") if os.path.exists(temp_path): os.remove(temp_path) - - # Get the transcription text - if result["segments"] and len(result["segments"]) > 0: - # Combine all segments - transcription = " ".join([segment["text"] for segment in result["segments"]]) - print(f"Transcription successful: '{transcription.strip()}'") - return transcription.strip() - else: - print("Transcription returned no segments") - return "" - except Exception as e: - print(f"Error in transcription: {str(e)}") - import traceback - traceback.print_exc() - if os.path.exists("temp_audio.wav"): - os.remove("temp_audio.wav") return "I heard something but couldn't understand it." 
- def generate_response(text: str, conversation_history: List[Segment]) -> str: """Generate a contextual response based on the transcribed text""" - # Simple response logic - can be replaced with a more sophisticated LLM in the future + # Simple response logic - can be replaced with a more sophisticated LLM responses = { - "hello": "Hello there! How are you doing today?", + "hello": "Hello there! How can I help you today?", + "hi": "Hi there! What can I do for you?", "how are you": "I'm doing well, thanks for asking! How about you?", "what is your name": "I'm Sesame, your voice assistant. How can I help you?", + "who are you": "I'm Sesame, an AI voice assistant. I'm here to chat with you!", "bye": "Goodbye! It was nice chatting with you.", "thank you": "You're welcome! Is there anything else I can help with?", "weather": "I don't have real-time weather data, but I hope it's nice where you are!", "help": "I can chat with you using natural voice. Just speak normally and I'll respond.", + "what can you do": "I can have a conversation with you, answer questions, and provide assistance with various topics.", } text_lower = text.lower() @@ -372,7 +396,7 @@ def generate_response(text: str, conversation_history: List[Segment]) -> str: else: return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?" -# Flask routes for serving static content +# Flask Routes @app.route('/') def index(): return send_from_directory(base_dir, 'index.html') @@ -391,11 +415,11 @@ def voice_chat_js(): def serve_static(path): return send_from_directory(static_dir, path) -# Socket.IO event handlers +# Socket.IO Event Handlers @socketio.on('connect') def handle_connect(): client_id = request.sid - print(f"Client connected: {client_id}") + logger.info(f"Client connected: {client_id}") # Initialize client context active_clients[client_id] = { @@ -414,7 +438,7 @@ def handle_disconnect(): client_id = request.sid if client_id in active_clients: del active_clients[client_id] - print(f"Client disconnected: {client_id}") + logger.info(f"Client disconnected: {client_id}") @socketio.on('generate') def handle_generate(data): @@ -427,7 +451,7 @@ def handle_generate(data): text = data.get('text', '') speaker_id = data.get('speaker', 0) - print(f"Generating audio for: '{text}' with speaker {speaker_id}") + logger.info(f"Generating audio for: '{text}' with speaker {speaker_id}") # Generate audio response audio_tensor = generator.generate( @@ -446,11 +470,12 @@ def handle_generate(data): audio_base64 = encode_audio_data(audio_tensor) emit('audio_response', { 'type': 'audio_response', - 'audio': audio_base64 + 'audio': audio_base64, + 'text': text }) except Exception as e: - print(f"Error generating audio: {str(e)}") + logger.error(f"Error generating audio: {e}") emit('error', { 'type': 'error', 'message': f"Error generating audio: {str(e)}" @@ -482,7 +507,7 @@ def handle_add_to_context(data): }) except Exception as e: - print(f"Error adding to context: {str(e)}") + logger.error(f"Error adding to context: {e}") emit('error', { 'type': 'error', 'message': f"Error processing audio: {str(e)}" @@ -512,6 +537,11 @@ def handle_stream_audio(data): speaker_id = data.get('speaker', 0) audio_data = data.get('audio', '') + # Skip if no audio data (might be just a connection test) + if not audio_data: + logger.debug("Empty audio data received, ignoring") + return + # Convert received audio to tensor audio_chunk = decode_audio_data(audio_data) @@ -522,7 +552,7 @@ def handle_stream_audio(data): 
client['energy_window'].clear() client['is_silence'] = False client['last_active_time'] = time.time() - print(f"[{client_id}] Streaming started with speaker ID: {speaker_id}") + logger.info(f"[{client_id[:8]}] Streaming started with speaker ID: {speaker_id}") emit('streaming_status', { 'type': 'streaming_status', 'status': 'started' @@ -553,52 +583,74 @@ def handle_stream_audio(data): if client['is_silence'] and silence_elapsed >= SILENCE_DURATION_SEC and len(client['streaming_buffer']) > 0: # User has stopped talking - process the collected audio - print(f"[{client_id}] Processing audio after {silence_elapsed:.2f}s of silence") + logger.info(f"[{client_id[:8]}] Processing audio after {silence_elapsed:.2f}s of silence") + process_complete_utterance(client_id, client, speaker_id) + + # If buffer gets too large without silence, process it anyway + elif len(client['streaming_buffer']) >= MAX_BUFFER_SIZE: + logger.info(f"[{client_id[:8]}] Processing long audio segment without silence") + process_complete_utterance(client_id, client, speaker_id, is_incomplete=True) - full_audio = torch.cat(client['streaming_buffer'], dim=0) + # Keep half of the buffer for context (sliding window approach) + half_point = len(client['streaming_buffer']) // 2 + client['streaming_buffer'] = client['streaming_buffer'][half_point:] - # Process with WhisperX speech-to-text - print(f"[{client_id}] Starting transcription with WhisperX...") - transcribed_text = transcribe_audio(full_audio) + except Exception as e: + import traceback + traceback.print_exc() + logger.error(f"Error processing streaming audio: {e}") + emit('error', { + 'type': 'error', + 'message': f"Error processing streaming audio: {str(e)}" + }) + +def process_complete_utterance(client_id, client, speaker_id, is_incomplete=False): + """Process a complete utterance (after silence or buffer limit)""" + try: + # Combine audio chunks + full_audio = torch.cat(client['streaming_buffer'], dim=0) + + # Process with speech-to-text + logger.info(f"[{client_id[:8]}] Starting transcription...") + transcribed_text = transcribe_audio(full_audio) + + # Add suffix for incomplete utterances + if is_incomplete: + transcribed_text += " (processing continued speech...)" + + # Log the transcription + logger.info(f"[{client_id[:8]}] Transcribed: '{transcribed_text}'") + + # Handle the transcription result + if transcribed_text: + # Add user message to context + user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio) + client['context_segments'].append(user_segment) - # Log the transcription - print(f"[{client_id}] Transcribed text: '{transcribed_text}'") + # Send the transcribed text to client + emit('transcription', { + 'type': 'transcription', + 'text': transcribed_text + }, room=client_id) - # Handle the transcription result - if transcribed_text: - # Add user message to context - user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio) - client['context_segments'].append(user_segment) - - # Send the transcribed text to client - emit('transcription', { - 'type': 'transcription', - 'text': transcribed_text - }) - + # Only generate a response if this is a complete utterance + if not is_incomplete: # Generate a contextual response response_text = generate_response(transcribed_text, client['context_segments']) - print(f"[{client_id}] Generating audio response: '{response_text}'") + logger.info(f"[{client_id[:8]}] Generating response: '{response_text}'") # Let the client know we're processing emit('processing_status', { 
'type': 'processing_status', 'status': 'generating_audio', 'message': 'Generating audio response...' - }) + }, room=client_id) # Generate audio for the response try: # Use a different speaker than the user ai_speaker_id = 1 if speaker_id == 0 else 0 - # Start audio generation with streaming (chunk by chunk) - audio_chunks = [] - - # This version tries to stream the audio generation in smaller chunks - # Note: CSM model doesn't natively support incremental generation, - # so we're simulating it here for a more responsive UI experience - # Generate the full response audio_tensor = generator.generate( text=response_text, @@ -621,60 +673,37 @@ def handle_stream_audio(data): 'type': 'audio_response', 'text': response_text, 'audio': audio_base64 - }) + }, room=client_id) - print(f"[{client_id}] Audio response sent: {len(audio_base64)} bytes") + logger.info(f"[{client_id[:8]}] Audio response sent") - except Exception as gen_error: - print(f"Error generating audio response: {str(gen_error)}") + except Exception as e: + logger.error(f"Error generating audio response: {e}") emit('error', { 'type': 'error', 'message': "Sorry, there was an error generating the audio response." - }) - else: - # If transcription failed, send a generic response - emit('error', { - 'type': 'error', - 'message': "Sorry, I couldn't understand what you said. Could you try again?" - }) - - # Clear buffer and reset silence detection + }, room=client_id) + else: + # If transcription failed, send a notification + emit('error', { + 'type': 'error', + 'message': "Sorry, I couldn't understand what you said. Could you try again?" + }, room=client_id) + + # Only clear buffer for complete utterances + if not is_incomplete: + # Reset state client['streaming_buffer'] = [] client['energy_window'].clear() client['is_silence'] = False client['last_active_time'] = time.time() - - # If buffer gets too large without silence, process it anyway - elif len(client['streaming_buffer']) >= 30: # ~6 seconds of audio at 5 chunks/sec - print(f"[{client_id}] Processing long audio segment without silence") - full_audio = torch.cat(client['streaming_buffer'], dim=0) - - # Process with WhisperX speech-to-text - transcribed_text = transcribe_audio(full_audio) - - if transcribed_text: - client['context_segments'].append( - Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio) - ) - - # Send the transcribed text to client - emit('transcription', { - 'type': 'transcription', - 'text': transcribed_text + " (processing continued speech...)" - }) - - # Keep half of the buffer for context (sliding window approach) - half_point = len(client['streaming_buffer']) // 2 - client['streaming_buffer'] = client['streaming_buffer'][half_point:] except Exception as e: - import traceback - traceback.print_exc() - print(f"Error processing streaming audio: {str(e)}") + logger.error(f"Error processing utterance: {e}") emit('error', { 'type': 'error', - 'message': f"Error processing streaming audio: {str(e)}" - }) + 'message': f"Error processing audio: {str(e)}" + }, room=client_id) @socketio.on('stop_streaming') def handle_stop_streaming(data): @@ -687,21 +716,8 @@ def handle_stop_streaming(data): if client['streaming_buffer'] and len(client['streaming_buffer']) > 5: # Process any remaining audio in the buffer - full_audio = torch.cat(client['streaming_buffer'], dim=0) - - # Process with WhisperX speech-to-text - transcribed_text = transcribe_audio(full_audio) - - if transcribed_text: - client['context_segments'].append( - Segment(text=transcribed_text, 
speaker=data.get("speaker", 0), audio=full_audio) - ) - - # Send the transcribed text to client - emit('transcription', { - 'type': 'transcription', - 'text': transcribed_text - }) + logger.info(f"[{client_id[:8]}] Processing final audio buffer on stop") + process_complete_utterance(client_id, client, data.get("speaker", 0)) client['streaming_buffer'] = [] emit('streaming_status', { @@ -709,18 +725,18 @@ def handle_stop_streaming(data): 'status': 'stopped' }) -def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size_ms=500): +def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size_ms=CHUNK_SIZE_MS): """Stream audio to client in chunks to simulate real-time generation""" try: if client_id not in active_clients: - print(f"Client {client_id} not found for streaming") + logger.warning(f"Client {client_id} not found for streaming") return # Calculate chunk size in samples chunk_size = int(generator.sample_rate * chunk_size_ms / 1000) total_chunks = math.ceil(audio_tensor.size(0) / chunk_size) - print(f"Streaming audio in {total_chunks} chunks of {chunk_size_ms}ms each") + logger.info(f"Streaming audio in {total_chunks} chunks of {chunk_size_ms}ms each") # Send initial response with text but no audio yet socketio.emit('audio_response_start', { @@ -758,29 +774,24 @@ def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size 'text': text }, room=client_id) - print(f"Audio streaming complete: {total_chunks} chunks sent") + logger.info(f"Audio streaming complete: {total_chunks} chunks sent") except Exception as e: - print(f"Error streaming audio to client: {str(e)}") + logger.error(f"Error streaming audio to client: {e}") import traceback traceback.print_exc() +# Main server start if __name__ == "__main__": print(f"\n{'='*60}") - print(f"🔊 Sesame AI Voice Chat Server (Flask Implementation)") + print(f"🔊 Sesame AI Voice Chat Server") print(f"{'='*60}") print(f"📡 Server Information:") print(f" - Local URL: http://localhost:5000") print(f" - Network URL: http://:5000") - print(f" - WebSocket: ws://:5000/socket.io") - print(f"{'='*60}") - print(f"💡 To make this server public:") - print(f" 1. Ensure port 5000 is open in your firewall") - print(f" 2. Set up port forwarding on your router to port 5000") - print(f" 3. Or use a service like ngrok with: ngrok http 5000") print(f"{'='*60}") print(f"🌐 Device: {device.upper()}") - print(f"🧠 Models loaded: Sesame CSM + WhisperX ({asr_model.device})") + print(f"🧠 Models: Sesame CSM + WhisperX ASR") print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}") print(f"{'='*60}") print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n")