diff --git a/Backend/server.py b/Backend/server.py
index b638e99..a6b70a3 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -16,56 +16,91 @@ import gc
 from collections import deque
 from threading import Lock
 
-# Add these lines right after your imports
-import torch
-import os
+# Add this at the top of your file, replacing your current CUDA setup
 
-# Handle CUDA issues
-os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Limit to first GPU only
-torch.backends.cudnn.benchmark = True
-
-# Set CUDA settings to avoid TF32 warnings
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-
-# Set compute type based on available hardware
-if torch.cuda.is_available():
-    device = "cuda"
-    compute_type = "float16" # Faster for CUDA
-else:
+# CUDA setup with robust error handling
+try:
+    # Handle CUDA issues
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Limit to first GPU only
+
+    # Try enabling TF32 precision
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+    except Exception:
+        pass # Ignore if not supported
+
+    # Check if CUDA is available
+    if torch.cuda.is_available():
+        try:
+            # Test CUDA functionality
+            x = torch.rand(10, device="cuda")
+            y = x + x
+            del x, y
+            device = "cuda"
+            compute_type = "float16"
+            print("CUDA is fully functional")
+        except Exception as cuda_error:
+            print(f"CUDA is available but not working correctly: {str(cuda_error)}")
+            device = "cpu"
+            compute_type = "int8"
+    else:
+        device = "cpu"
+        compute_type = "int8"
+except Exception as e:
+    print(f"Error setting up CUDA: {str(e)}")
     device = "cpu"
-    compute_type = "int8" # Better for CPU
+    compute_type = "int8"
 
 print(f"Using device: {device} with compute type: {compute_type}")
 
-# Select device
-if torch.cuda.is_available():
-    device = "cuda"
-else:
-    device = "cpu"
-print(f"Using device: {device}")
+# Initialize the Sesame CSM model with robust error handling
+try:
+    print(f"Loading Sesame CSM model on {device}...")
+    generator = load_csm_1b(device=device)
+    print("Sesame CSM model loaded successfully")
+except Exception as model_error:
+    print(f"Error loading Sesame CSM on {device}: {str(model_error)}")
+    if device == "cuda":
+        # Try on CPU as fallback
+        try:
+            print("Trying to load Sesame CSM on CPU instead...")
+            device = "cpu" # Update global device setting
+            generator = load_csm_1b(device="cpu")
+            print("Sesame CSM model loaded on CPU successfully")
+        except Exception as cpu_error:
+            print(f"Fatal error - could not load Sesame CSM model: {str(cpu_error)}")
+            raise RuntimeError("Failed to load speech synthesis model")
+    else:
+        # Already tried CPU and it failed
+        raise RuntimeError("Failed to load speech synthesis model on any device")
 
-# Initialize the model
-generator = load_csm_1b(device=device)
-
-# Initialize WhisperX for ASR
+# Initialize WhisperX for ASR with robust error handling
 print("Loading WhisperX model...")
 try:
-    # Try to load a smaller model for faster response times
-    asr_model = whisperx.load_model("small", device, compute_type=compute_type)
-    print("WhisperX 'small' model loaded successfully")
+    # First try the smallest model ("tiny") to avoid memory issues
+    asr_model = whisperx.load_model("tiny", device, compute_type=compute_type)
+    print("WhisperX 'tiny' model loaded successfully")
+
+    # If tiny worked and we have CUDA, try upgrading to small
+    if device == "cuda":
+        try:
+            asr_model = whisperx.load_model("small", device, compute_type=compute_type)
+            print("WhisperX 'small' model loaded successfully")
+        except Exception as upgrade_error:
+            print(f"Staying with 'tiny' model: {str(upgrade_error)}")
 except Exception as e:
-    print(f"Error loading 'small' model: {str(e)}")
+    print(f"Error loading models on {device}: {str(e)}")
+    print("Falling back to CPU model")
     try:
-        # Fall back to tiny model if small fails
-        asr_model = whisperx.load_model("tiny", device, compute_type=compute_type)
-        print("WhisperX 'tiny' model loaded as fallback")
-    except Exception as e2:
-        print(f"Error loading fallback model: {str(e2)}")
-        print("Trying CPU model as last resort")
-        # Last resort - try CPU
+        # Force CPU as last resort
+        device = "cpu"
+        compute_type = "int8"
         asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
         print("WhisperX loaded on CPU as last resort")
+    except Exception as cpu_error:
+        print(f"Fatal error - could not load any model: {str(cpu_error)}")
+        raise RuntimeError("No ASR model could be loaded. Please check your CUDA installation.")
 
 # Silence detection parameters
 SILENCE_THRESHOLD = 0.01 # Adjust based on your audio normalization
@@ -226,7 +261,7 @@ def encode_audio_data(audio_tensor: torch.Tensor) -> str:
 
 
 def transcribe_audio(audio_tensor: torch.Tensor) -> str:
-    """Transcribe audio using WhisperX"""
+    """Transcribe audio using WhisperX with robust error handling"""
     try:
         # Save the tensor to a temporary file
         temp_path = os.path.join(base_dir, "temp_audio.wav")
@@ -234,9 +269,38 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
 
         print(f"Transcribing audio file: {temp_path} (size: {os.path.getsize(temp_path)} bytes)")
 
-        # Load and transcribe the audio
-        audio = whisperx.load_audio(temp_path)
-        result = asr_model.transcribe(audio, batch_size=16)
+        # Load the audio file using whisperx's function
+        try:
+            audio = whisperx.load_audio(temp_path)
+        except Exception as audio_load_error:
+            print(f"WhisperX load_audio failed: {str(audio_load_error)}")
+            # Fall back to manual loading
+            import soundfile as sf
+            audio, sr = sf.read(temp_path)
+            if sr != 16000: # WhisperX expects 16kHz audio
+                from scipy import signal
+                audio = signal.resample(audio, int(len(audio) * 16000 / sr))
+
+        # Transcribe with error handling for CUDA issues
+        global asr_model # must appear before asr_model is first used so the CPU fallback can replace the shared model
+        try:
+            # Try with original device
+            result = asr_model.transcribe(audio, batch_size=8)
+        except RuntimeError as cuda_error:
+            if "CUDA" in str(cuda_error) or "libcudnn" in str(cuda_error):
+                print(f"CUDA error in transcription, falling back to CPU: {str(cuda_error)}")
+
+                # Try to load a CPU model as fallback
+                try:
+                    # Move model to CPU and try again
+                    asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
+                    result = asr_model.transcribe(audio, batch_size=1)
+                except Exception as e:
+                    print(f"CPU fallback also failed: {str(e)}")
+                    return "I'm having trouble processing audio right now."
+            else:
+                # Re-raise if it's not a CUDA error
+                raise
 
         # Clean up
         if os.path.exists(temp_path):
@@ -257,7 +321,7 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
         traceback.print_exc()
         if os.path.exists("temp_audio.wav"):
             os.remove("temp_audio.wav")
-        return ""
+        return "I heard something but couldn't understand it."
 
 
 def generate_response(text: str, conversation_history: List[Segment]) -> str:
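For quick verification outside the server, the module-level device probe this patch introduces can be exercised on its own. The sketch below is a minimal, standalone extraction of that probe-and-fallback logic; the helper name pick_device_and_compute_type is ours for illustration and does not exist in server.py, and the only assumption is that PyTorch is installed.

import os
import torch

def pick_device_and_compute_type():
    """Probe CUDA with a tiny allocation and fall back to CPU/int8 on any failure."""
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")  # keep work on the first GPU, as in the patch
    try:
        if torch.cuda.is_available():
            # A small allocation plus one op catches broken driver/cuDNN installs
            # that torch.cuda.is_available() alone does not reveal.
            x = torch.rand(10, device="cuda")
            _ = x + x
            del x
            return "cuda", "float16"
    except Exception as cuda_error:
        print(f"CUDA present but unusable, using CPU: {cuda_error}")
    return "cpu", "int8"

if __name__ == "__main__":
    device, compute_type = pick_device_and_compute_type()
    print(f"Using device: {device} with compute type: {compute_type}")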