diff --git a/Backend/server.py b/Backend/server.py
index 8ba56b4..8f4e278 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -8,7 +8,6 @@ import logging
 import numpy as np
 import torch
 import torchaudio
-import whisperx
 from io import BytesIO
 from typing import List, Dict, Any, Optional
 from flask import Flask, request, send_from_directory, Response
@@ -25,68 +24,24 @@ logging.basicConfig(
 )
 logger = logging.getLogger("sesame-server")
 
-# CUDA Environment Setup
-def setup_cuda_environment():
-    """Set up CUDA environment with proper error handling"""
-    # Search for CUDA libraries in common locations
-    cuda_lib_dirs = [
-        "/usr/local/cuda/lib64",
-        "/usr/lib/x86_64-linux-gnu",
-        "/usr/local/cuda/extras/CUPTI/lib64"
-    ]
-
-    # Add directories to LD_LIBRARY_PATH if they exist
-    current_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
-    for cuda_dir in cuda_lib_dirs:
-        if os.path.exists(cuda_dir) and cuda_dir not in current_ld_path:
-            if current_ld_path:
-                os.environ['LD_LIBRARY_PATH'] = f"{current_ld_path}:{cuda_dir}"
-            else:
-                os.environ['LD_LIBRARY_PATH'] = cuda_dir
-            current_ld_path = os.environ['LD_LIBRARY_PATH']
-
-    logger.info(f"LD_LIBRARY_PATH set to: {os.environ.get('LD_LIBRARY_PATH', 'not set')}")
-
-    # Determine best compute device
-    device = "cpu"
-    compute_type = "int8"
-
+# Determine best compute device
+if torch.backends.mps.is_available():
+    device = "mps"
+elif torch.cuda.is_available():
     try:
-        # Set CUDA preferences
-        os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Limit to first GPU only
-
-        # Try enabling TF32 precision if available
-        try:
-            torch.backends.cuda.matmul.allow_tf32 = True
-            torch.backends.cudnn.allow_tf32 = True
-            torch.backends.cudnn.enabled = True
-            torch.backends.cudnn.benchmark = True
-        except Exception as e:
-            logger.warning(f"Could not set advanced CUDA options: {e}")
-
-        # Test if CUDA is functional
-        if torch.cuda.is_available():
-            try:
-                # Test basic CUDA operations
-                x = torch.rand(10, device="cuda")
-                y = x + x
-                del x, y
-                torch.cuda.empty_cache()
-                device = "cuda"
-                compute_type = "float16"
-                logger.info("CUDA is fully functional")
-            except Exception as e:
-                logger.warning(f"CUDA available but not working correctly: {e}")
-                device = "cpu"
-        else:
-            logger.info("CUDA is not available, using CPU")
+        # Test CUDA functionality
+        torch.rand(10, device="cuda")
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+        device = "cuda"
+        logger.info("CUDA is fully functional")
     except Exception as e:
-        logger.error(f"Error setting up computing environment: {e}")
-
-    return device, compute_type
-
-# Set up the compute environment
-device, compute_type = setup_cuda_environment()
+        logger.warning(f"CUDA available but not working correctly: {e}")
+        device = "cpu"
+else:
+    device = "cpu"
+    logger.info("Using CPU")
 
 # Constants and Configuration
 SILENCE_THRESHOLD = 0.01
@@ -99,9 +54,37 @@ base_dir = os.path.dirname(os.path.abspath(__file__))
 static_dir = os.path.join(base_dir, "static")
 os.makedirs(static_dir, exist_ok=True)
 
+# Define a simple energy-based speech detector
+class SpeechDetector:
+    def __init__(self):
+        self.min_speech_energy = 0.01
+        self.speech_window = 0.2  # seconds
+
+    def detect_speech(self, audio_tensor, sample_rate):
+        # Calculate frame size based on window size
+        frame_size = int(sample_rate * self.speech_window)
+
+        # If audio is shorter than frame size, use the entire audio
+        if audio_tensor.shape[0] < frame_size:
+            frames = [audio_tensor]
+        else:
+            # Split audio into frames
+            frames = [audio_tensor[i:i+frame_size] for i in range(0, len(audio_tensor), frame_size)]
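+            # Note: the final frame may be shorter than frame_size; the
+            # per-frame mean below handles a short tail without special-casing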
+
+        # Calculate energy per frame
+        energies = [torch.mean(frame**2).item() for frame in frames]
+
+        # Determine if there's speech based on energy threshold
+        has_speech = any(e > self.min_speech_energy for e in energies)
+
+        return has_speech
+
+speech_detector = SpeechDetector()
+logger.info("Initialized simple speech detector")
+
 # Model Loading Functions
 def load_speech_models():
-    """Load all required speech models with fallbacks"""
+    """Load speech generation model"""
     # Load speech generation model (Sesame CSM)
     try:
         logger.info(f"Loading Sesame CSM model on {device}...")
@@ -120,52 +103,10 @@ def load_speech_models():
     else:
         raise RuntimeError("Failed to load speech synthesis model on any device")
 
-    # Load ASR model (WhisperX)
-    try:
-        logger.info("Loading WhisperX model...")
-        # Start with the tiny model on CPU for reliable initialization
-        asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
-        logger.info("WhisperX 'tiny' model loaded on CPU successfully")
-
-        # Try upgrading to GPU if available
-        if device == "cuda":
-            try:
-                logger.info("Trying to load WhisperX on CUDA...")
-                # Test with a tiny model first
-                test_audio = torch.zeros(16000)  # 1 second of silence
-
-                cuda_model = whisperx.load_model("tiny", "cuda", compute_type="float16")
-                # Test the model with real inference
-                _ = cuda_model.transcribe(test_audio.numpy(), batch_size=1)
-                asr_model = cuda_model
-                logger.info("WhisperX model running on CUDA successfully")
-
-                # Try to upgrade to small model
-                try:
-                    small_model = whisperx.load_model("small", "cuda", compute_type="float16")
-                    _ = small_model.transcribe(test_audio.numpy(), batch_size=1)
-                    asr_model = small_model
-                    logger.info("WhisperX 'small' model loaded on CUDA successfully")
-                except Exception as e:
-                    logger.warning(f"Staying with 'tiny' model on CUDA: {e}")
-            except Exception as e:
-                logger.warning(f"CUDA loading failed, staying with CPU model: {e}")
-    except Exception as e:
-        logger.error(f"Error loading WhisperX model: {e}")
-        # Create a minimal dummy model as last resort
-        class DummyModel:
-            def __init__(self):
-                self.device = "cpu"
-            def transcribe(self, *args, **kwargs):
-                return {"segments": [{"text": "Speech recognition currently unavailable."}]}
-
-        asr_model = DummyModel()
-        logger.warning("Using dummy transcription model - ASR functionality limited")
-
-    return generator, asr_model
+    return generator
 
-# Load speech models
-generator, asr_model = load_speech_models()
+# Load speech model
+generator = load_speech_models()
 
 # Set up Flask and Socket.IO
 app = Flask(__name__)
@@ -307,63 +248,23 @@ def encode_audio_data(audio_tensor: torch.Tensor) -> str:
     buf.seek(0)
     return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
 
-def transcribe_audio(audio_tensor: torch.Tensor) -> str:
-    """Transcribe audio using WhisperX with robust error handling"""
-    global asr_model
+def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str:
+    """Process speech and return a simple response"""
+    # In this simplified version, we'll just check if there's sound
+    # and provide basic responses instead of doing actual speech recognition
 
-    try:
-        # Save the tensor to a temporary file
-        temp_path = os.path.join(base_dir, f"temp_audio_{time.time()}.wav")
-        torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
+    if speech_detector and speech_detector.detect_speech(audio_tensor, generator.sample_rate):
+        # Generate a response based on audio energy
+        energy = torch.mean(torch.abs(audio_tensor)).item()
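+        # energy is the mean absolute amplitude over the whole utterance,
+        # roughly 0.0 (silence) to 1.0 (full scale); thresholds are heuristic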
 
-        logger.info(f"Transcribing audio file: {os.path.getsize(temp_path)} bytes")
-
-        # Load the audio for WhisperX
-        try:
-            audio = whisperx.load_audio(temp_path)
-        except Exception as e:
-            logger.warning(f"WhisperX load_audio failed: {e}")
-            # Fall back to manual loading
-            import soundfile as sf
-            audio, sr = sf.read(temp_path)
-            if sr != 16000:  # WhisperX expects 16kHz audio
-                from scipy import signal
-                audio = signal.resample(audio, int(len(audio) * 16000 / sr))
-
-        # Transcribe with error handling
-        try:
-            result = asr_model.transcribe(audio, batch_size=4)
-        except RuntimeError as e:
-            if "CUDA" in str(e) or "libcudnn" in str(e):
-                logger.warning(f"CUDA error in transcription, falling back to CPU: {e}")
-                try:
-                    # Try CPU model
-                    cpu_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
-                    result = cpu_model.transcribe(audio, batch_size=1)
-                    # Update the global model if the original one is broken
-                    asr_model = cpu_model
-                except Exception as cpu_e:
-                    logger.error(f"CPU fallback failed: {cpu_e}")
-                    return "I'm having trouble processing audio right now."
-            else:
-                raise
-        finally:
-            # Clean up
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-
-        # Extract text from segments
-        if result["segments"] and len(result["segments"]) > 0:
-            transcription = " ".join([segment["text"] for segment in result["segments"]])
-            logger.info(f"Transcription: '{transcription.strip()}'")
-            return transcription.strip()
-
-        return ""
-    except Exception as e:
-        logger.error(f"Error in transcription: {e}")
-        if os.path.exists(temp_path):
-            os.remove(temp_path)
-        return "I heard something but couldn't understand it."
+        if energy > 0.1:  # Louder speech
+            return "I heard you speaking clearly. How can I help you today?"
+        elif energy > 0.05:  # Moderate speech
+            return "I heard you say something. Could you please repeat that?"
+        else:  # Soft speech
+            return "I detected some speech, but it was quite soft. Could you speak up a bit?"
+    else:
+        return "I didn't detect any speech. Could you please try again?"
 
 def generate_response(text: str, conversation_history: List[Segment]) -> str:
     """Generate a contextual response based on the transcribed text"""
@@ -394,7 +295,7 @@ def generate_response(text: str, conversation_history: List[Segment]) -> str:
     elif len(text) < 10:
         return "Thanks for your message. Could you elaborate a bit more?"
     else:
-        return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"
+        return f"I heard you speaking. That's interesting! Can you tell me more about that?"
 
 # Flask Routes
 @app.route('/')
@@ -610,33 +511,32 @@ def process_complete_utterance(client_id, client, speaker_id, is_incomplete=Fals
     # Combine audio chunks
     full_audio = torch.cat(client['streaming_buffer'], dim=0)
 
-    # Process with speech-to-text
-    logger.info(f"[{client_id[:8]}] Starting transcription...")
-    transcribed_text = transcribe_audio(full_audio)
+    # Process audio to generate a response (no speech recognition)
+    generated_text = process_speech(full_audio, client_id)
 
     # Add suffix for incomplete utterances
     if is_incomplete:
-        transcribed_text += " (processing continued speech...)"
+        generated_text += " (processing continued speech...)"
 
-    # Log the transcription
-    logger.info(f"[{client_id[:8]}] Transcribed: '{transcribed_text}'")
+    # Log the generated text
+    logger.info(f"[{client_id[:8]}] Generated text: '{generated_text}'")
 
-    # Handle the transcription result
-    if transcribed_text:
+    # Handle the result
+    if generated_text:
         # Add user message to context
-        user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
+        user_segment = Segment(text=generated_text, speaker=speaker_id, audio=full_audio)
         client['context_segments'].append(user_segment)
 
-        # Send the transcribed text to client
+        # Send the text to client
         emit('transcription', {
             'type': 'transcription',
-            'text': transcribed_text
+            'text': generated_text
         }, room=client_id)
 
         # Only generate a response if this is a complete utterance
         if not is_incomplete:
             # Generate a contextual response
-            response_text = generate_response(transcribed_text, client['context_segments'])
+            response_text = generate_response(generated_text, client['context_segments'])
             logger.info(f"[{client_id[:8]}] Generating response: '{response_text}'")
 
             # Let the client know we're processing
@@ -684,7 +584,7 @@ def process_complete_utterance(client_id, client, speaker_id, is_incomplete=Fals
                 'message': "Sorry, there was an error generating the audio response."
             }, room=client_id)
     else:
-        # If transcription failed, send a notification
+        # If processing failed, send a notification
         emit('error', {
             'type': 'error',
             'message': "Sorry, I couldn't understand what you said. Could you try again?"
         }, room=client_id)
@@ -791,7 +691,7 @@ if __name__ == "__main__":
     print(f" - Network URL: http://:5000")
     print(f"{'='*60}")
    print(f"🌐 Device: {device.upper()}")
-    print(f"🧠 Models: Sesame CSM + WhisperX ASR")
+    print(f"🧠 Models: Sesame CSM (TTS only)")
     print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}")
     print(f"{'='*60}")
     print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n")
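
Reviewer note (not part of the patch): a standalone sanity check for the new
energy-based detector. The sketch below duplicates the SpeechDetector logic
from the hunk above rather than importing server.py, since importing the module
would trigger CSM model loading. The 24 kHz sample rate is an assumption made
here for the example; the server itself passes generator.sample_rate.

    import math
    import torch

    # Same logic as the SpeechDetector added in this patch
    class SpeechDetector:
        def __init__(self):
            self.min_speech_energy = 0.01
            self.speech_window = 0.2  # seconds

        def detect_speech(self, audio_tensor, sample_rate):
            frame_size = int(sample_rate * self.speech_window)
            if audio_tensor.shape[0] < frame_size:
                frames = [audio_tensor]
            else:
                frames = [audio_tensor[i:i+frame_size]
                          for i in range(0, len(audio_tensor), frame_size)]
            energies = [torch.mean(frame**2).item() for frame in frames]
            return any(e > self.min_speech_energy for e in energies)

    detector = SpeechDetector()
    sr = 24000  # assumed sample rate; the server reads generator.sample_rate

    t = torch.arange(sr, dtype=torch.float32) / sr
    silence = torch.zeros(sr)                        # 1 s of digital silence
    tone = 0.3 * torch.sin(2 * math.pi * 440.0 * t)  # clearly audible 440 Hz tone

    print(detector.detect_speech(silence, sr))  # False: every frame energy is 0.0
    print(detector.detect_speech(tone, sr))     # True: mean square ~0.045 > 0.01

One design consequence worth noting: detect_speech keys off per-frame energy
(mean square), while process_speech picks its canned reply from the mean
absolute amplitude of the whole utterance, so a short loud burst inside a long
quiet recording can pass detection yet still land in the "soft speech" reply.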