diff --git a/Backend/server.py b/Backend/server.py
index bacf793..b638e99 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -16,6 +16,28 @@ import gc
 from collections import deque
 from threading import Lock
 
+# Device / CUDA configuration (torch and os may already be imported above;
+# re-importing them here is harmless)
+import torch
+import os
+
+# Pin to the first GPU and let cuDNN auto-tune kernels
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Limit to first GPU only
+torch.backends.cudnn.benchmark = True
+
+# Enable TF32 kernels: faster matmuls on Ampere+ GPUs, and avoids TF32 warnings
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+# Set compute type based on available hardware
+if torch.cuda.is_available():
+    device = "cuda"
+    compute_type = "float16"  # Faster for CUDA
+else:
+    device = "cpu"
+    compute_type = "int8"  # Better for CPU
+
+print(f"Using device: {device} with compute type: {compute_type}")
+
 # Select device
 if torch.cuda.is_available():
     device = "cuda"
@@ -28,9 +50,22 @@ generator = load_csm_1b(device=device)
 
 # Initialize WhisperX for ASR
 print("Loading WhisperX model...")
-# Use a smaller model for faster response times
-asr_model = whisperx.load_model("medium", device, compute_type="float16")
-print("WhisperX model loaded!")
+try:
+    # Try to load a smaller model for faster response times
+    asr_model = whisperx.load_model("small", device, compute_type=compute_type)
+    print("WhisperX 'small' model loaded successfully")
+except Exception as e:
+    print(f"Error loading 'small' model: {str(e)}")
+    try:
+        # Fall back to the tiny model if 'small' fails
+        asr_model = whisperx.load_model("tiny", device, compute_type=compute_type)
+        print("WhisperX 'tiny' model loaded as fallback")
+    except Exception as e2:
+        print(f"Error loading fallback model: {str(e2)}")
+        print("Trying CPU model as last resort")
+        # Last resort - try CPU
+        asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
+        print("WhisperX loaded on CPU as last resort")
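+
+# Sketch (not part of the original patch): the try/except ladder above could
+# be expressed as one loop over fallback candidates, kept here as a comment:
+#
+#     for size, dev, ct in [("small", device, compute_type),
+#                           ("tiny", device, compute_type),
+#                           ("tiny", "cpu", "int8")]:
+#         try:
+#             asr_model = whisperx.load_model(size, dev, compute_type=ct)
+#             break
+#         except Exception as exc:
+#             print(f"Loading WhisperX '{size}' on {dev} failed: {exc}")
+#     else:
+#         raise RuntimeError("No WhisperX model could be loaded")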
 
 # Silence detection parameters
 SILENCE_THRESHOLD = 0.01  # Adjust based on your audio normalization
@@ -53,76 +88,130 @@ active_clients = {}  # Map client_id to client context
 
 # Helper function to convert audio data
 def decode_audio_data(audio_data: str) -> torch.Tensor:
-    """Decode base64 audio data to a torch tensor"""
+    """Decode base64 audio data to a torch tensor with improved error handling"""
     try:
         # Skip empty audio data
-        if not audio_data:
-            print("Empty audio data received")
+        if not audio_data or len(audio_data) < 100:
+            print("Empty or too short audio data received")
            return torch.zeros(generator.sample_rate // 2)  # 0.5 seconds of silence
 
         # Extract the actual base64 content
         if ',' in audio_data:
+            # Handle data URL format (data:audio/wav;base64,...)
             audio_data = audio_data.split(',')[1]
 
         # Decode base64 audio data
         try:
             binary_data = base64.b64decode(audio_data)
             print(f"Decoded base64 data: {len(binary_data)} bytes")
+
+            # Check if we have enough data for a valid WAV
+            if len(binary_data) < 44:  # a WAV header alone is 44 bytes
+                print("Data too small to be a valid WAV file")
+                return torch.zeros(generator.sample_rate // 2)
         except Exception as e:
             print(f"Base64 decoding error: {str(e)}")
             return torch.zeros(generator.sample_rate // 2)
 
-        # Debug: save the raw binary data to examine with external tools
+        # Save for debugging
         debug_path = os.path.join(base_dir, "debug_incoming.wav")
         with open(debug_path, 'wb') as f:
             f.write(binary_data)
-        print(f"Saved debug file to {debug_path}")
-
-        # Load audio from binary data
+        print(f"Saved debug file: {debug_path}")
+
+        # Approach 1: Load directly with torchaudio
         try:
             with BytesIO(binary_data) as temp_file:
+                temp_file.seek(0)  # Ensure we're at the start of the buffer
                 audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")
-                print(f"Loaded audio: shape={audio_tensor.shape}, sample_rate={sample_rate}Hz")
+                print(f"Direct loading success: shape={audio_tensor.shape}, rate={sample_rate}Hz")
 
                 # Check if audio is valid
                 if audio_tensor.numel() == 0 or torch.isnan(audio_tensor).any():
-                    print("Warning: Empty or invalid audio data detected")
-                    return torch.zeros(generator.sample_rate // 2)
+                    raise ValueError("Empty or invalid audio tensor detected")
         except Exception as e:
-            print(f"Audio loading error: {str(e)}")
-            # Try saving to a temporary file instead of loading from BytesIO
+            print(f"Direct loading failed: {str(e)}")
+
+            # Approach 2: Try to fix/normalize the WAV data
             try:
-                temp_path = os.path.join(base_dir, "temp_incoming.wav")
+                # Sometimes WAV headers are malformed; write to disk and re-read
+                temp_path = os.path.join(base_dir, "temp_fixing.wav")
                 with open(temp_path, 'wb') as f:
                     f.write(binary_data)
-                print(f"Trying to load from file: {temp_path}")
-                audio_tensor, sample_rate = torchaudio.load(temp_path, format="wav")
-                print(f"Loaded from file: shape={audio_tensor.shape}, sample_rate={sample_rate}Hz")
-                os.remove(temp_path)
+
+                # Use a simpler numpy approach as backup
+                import numpy as np
+                import wave
+
+                try:
+                    with wave.open(temp_path, 'rb') as wf:
+                        n_channels = wf.getnchannels()
+                        sample_width = wf.getsampwidth()
+                        sample_rate = wf.getframerate()
+                        n_frames = wf.getnframes()
+
+                        # Read the frames
+                        frames = wf.readframes(n_frames)
+                        print(f"Wave reading: channels={n_channels}, rate={sample_rate}Hz, frames={n_frames}")
+
+                        # Convert to numpy and then to torch
+                        if sample_width == 2:  # 16-bit audio
+                            data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
+                        elif sample_width == 1:  # 8-bit audio
+                            data = np.frombuffer(frames, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0
+                        else:
+                            raise ValueError(f"Unsupported sample width: {sample_width}")
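+                        # Scaling note: int16 spans [-32768, 32767], so dividing
+                        # by 32768.0 maps samples into [-1.0, 1.0); 8-bit WAV
+                        # samples are unsigned, hence the -1.0 shift after /128.0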
{str(e2)}") + print("Returning silence as fallback") return torch.zeros(generator.sample_rate // 2) + # Ensure audio is the right shape (mono) + if len(audio_tensor.shape) > 1 and audio_tensor.shape[0] > 1: + audio_tensor = torch.mean(audio_tensor, dim=0) + + # Ensure we have a 1D tensor + audio_tensor = audio_tensor.squeeze() + # Resample if needed if sample_rate != generator.sample_rate: try: print(f"Resampling from {sample_rate}Hz to {generator.sample_rate}Hz") - audio_tensor = torchaudio.functional.resample( - audio_tensor.squeeze(0), + resampler = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=generator.sample_rate ) - print(f"Resampled audio shape: {audio_tensor.shape}") + audio_tensor = resampler(audio_tensor) except Exception as e: print(f"Resampling error: {str(e)}") - return torch.zeros(generator.sample_rate // 2) - else: - audio_tensor = audio_tensor.squeeze(0) - - print(f"Final audio tensor shape: {audio_tensor.shape}") + # If resampling fails, just return the original audio + # The model can often handle different sample rates + + # Normalize audio to avoid issues + if torch.abs(audio_tensor).max() > 0: + audio_tensor = audio_tensor / torch.abs(audio_tensor).max() + + print(f"Final audio tensor: shape={audio_tensor.shape}, min={audio_tensor.min().item():.4f}, max={audio_tensor.max().item():.4f}") return audio_tensor except Exception as e: - print(f"Error decoding audio: {str(e)}") + print(f"Unhandled error in decode_audio_data: {str(e)}") # Return a small silent audio segment as fallback return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence @@ -143,6 +232,8 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str: temp_path = os.path.join(base_dir, "temp_audio.wav") torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate) + print(f"Transcribing audio file: {temp_path} (size: {os.path.getsize(temp_path)} bytes)") + # Load and transcribe the audio audio = whisperx.load_audio(temp_path) result = asr_model.transcribe(audio, batch_size=16) @@ -155,11 +246,15 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str: if result["segments"] and len(result["segments"]) > 0: # Combine all segments transcription = " ".join([segment["text"] for segment in result["segments"]]) + print(f"Transcription successful: '{transcription.strip()}'") return transcription.strip() else: + print("Transcription returned no segments") return "" except Exception as e: print(f"Error in transcription: {str(e)}") + import traceback + traceback.print_exc() if os.path.exists("temp_audio.wav"): os.remove("temp_audio.wav") return "" @@ -385,43 +480,73 @@ def handle_stream_audio(data): # Log the transcription print(f"[{client_id}] Transcribed text: '{transcribed_text}'") - # Add to conversation context + # Handle the transcription result if transcribed_text: + # Add user message to context user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio) client['context_segments'].append(user_segment) - # Generate a contextual response - response_text = generate_response(transcribed_text, client['context_segments']) - # Send the transcribed text to client emit('transcription', { 'type': 'transcription', 'text': transcribed_text }) - # Generate audio for the response - audio_tensor = generator.generate( - text=response_text, - speaker=1 if speaker_id == 0 else 0, # Use opposite speaker - context=client['context_segments'], - max_audio_length_ms=10_000, - ) + # Generate a contextual response + response_text = 
+
+        print(f"Final audio tensor: shape={audio_tensor.shape}, min={audio_tensor.min().item():.4f}, max={audio_tensor.max().item():.4f}")
         return audio_tensor
 
     except Exception as e:
-        print(f"Error decoding audio: {str(e)}")
+        print(f"Unhandled error in decode_audio_data: {str(e)}")
         # Return a small silent audio segment as fallback
         return torch.zeros(generator.sample_rate // 2)  # 0.5 seconds of silence
 
@@ -143,6 +232,8 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
         temp_path = os.path.join(base_dir, "temp_audio.wav")
         torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
 
+        print(f"Transcribing audio file: {temp_path} (size: {os.path.getsize(temp_path)} bytes)")
+
         # Load and transcribe the audio
         audio = whisperx.load_audio(temp_path)
         result = asr_model.transcribe(audio, batch_size=16)
@@ -155,11 +246,15 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
         if result["segments"] and len(result["segments"]) > 0:
             # Combine all segments
             transcription = " ".join([segment["text"] for segment in result["segments"]])
+            print(f"Transcription successful: '{transcription.strip()}'")
             return transcription.strip()
         else:
+            print("Transcription returned no segments")
             return ""
     except Exception as e:
         print(f"Error in transcription: {str(e)}")
+        import traceback
+        traceback.print_exc()
         if os.path.exists("temp_audio.wav"):
             os.remove("temp_audio.wav")
         return ""
@@ -385,43 +480,73 @@ def handle_stream_audio(data):
             # Log the transcription
             print(f"[{client_id}] Transcribed text: '{transcribed_text}'")
 
-            # Add to conversation context
+            # Handle the transcription result
             if transcribed_text:
+                # Add user message to context
                 user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
                 client['context_segments'].append(user_segment)
 
-                # Generate a contextual response
-                response_text = generate_response(transcribed_text, client['context_segments'])
-
                 # Send the transcribed text to client
                 emit('transcription', {
                     'type': 'transcription',
                     'text': transcribed_text
                 })
 
-                # Generate audio for the response
-                audio_tensor = generator.generate(
-                    text=response_text,
-                    speaker=1 if speaker_id == 0 else 0,  # Use opposite speaker
-                    context=client['context_segments'],
-                    max_audio_length_ms=10_000,
-                )
+                # Generate a contextual response
+                response_text = generate_response(transcribed_text, client['context_segments'])
+                print(f"[{client_id}] Generating audio response: '{response_text}'")
 
-                # Add response to context
-                ai_segment = Segment(
-                    text=response_text,
-                    speaker=1 if speaker_id == 0 else 0,
-                    audio=audio_tensor
-                )
-                client['context_segments'].append(ai_segment)
-
-                # Convert audio to base64 and send back to client
-                audio_base64 = encode_audio_data(audio_tensor)
-                emit('audio_response', {
-                    'type': 'audio_response',
-                    'text': response_text,
-                    'audio': audio_base64
+                # Let the client know we're processing
+                emit('processing_status', {
+                    'type': 'processing_status',
+                    'status': 'generating_audio',
+                    'message': 'Generating audio response...'
                 })
+
+                # Generate audio for the response
+                try:
+                    # Use a different speaker than the user
+                    ai_speaker_id = 1 if speaker_id == 0 else 0
+
+                    # Note: the CSM model doesn't natively support incremental
+                    # generation, so the full clip is produced in one call and
+                    # only its delivery to the client can be chunked afterwards
+
+                    # Generate the full response
+                    audio_tensor = generator.generate(
+                        text=response_text,
+                        speaker=ai_speaker_id,
+                        context=client['context_segments'],
+                        max_audio_length_ms=10_000,
+                    )
+
+                    # Add response to context
+                    ai_segment = Segment(
+                        text=response_text,
+                        speaker=ai_speaker_id,
+                        audio=audio_tensor
+                    )
+                    client['context_segments'].append(ai_segment)
+
+                    # Convert audio to base64 and send back to client
+                    audio_base64 = encode_audio_data(audio_tensor)
+                    emit('audio_response', {
+                        'type': 'audio_response',
+                        'text': response_text,
+                        'audio': audio_base64
+                    })
+
+                    print(f"[{client_id}] Audio response sent: {len(audio_base64)} base64 chars")
+
+                except Exception as gen_error:
+                    print(f"Error generating audio response: {str(gen_error)}")
+                    emit('error', {
+                        'type': 'error',
+                        'message': "Sorry, there was an error generating the audio response."
+                    })
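+                    # (stream_audio_to_client() below offers a chunked
+                    # alternative to the single emit above; nothing in this
+                    # patch calls it yet)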
             else:
                 # If transcription failed, send a generic response
                 emit('error', {
@@ -437,6 +562,7 @@ def handle_stream_audio(data):
 
         # If buffer gets too large without silence, process it anyway
         elif len(client['streaming_buffer']) >= 30:  # ~6 seconds of audio at 5 chunks/sec
+            print(f"[{client_id}] Processing long audio segment without silence")
             full_audio = torch.cat(client['streaming_buffer'], dim=0)
 
             # Process with WhisperX speech-to-text
@@ -453,7 +579,9 @@ def handle_stream_audio(data):
                 'text': transcribed_text + " (processing continued speech...)"
             })
 
-            client['streaming_buffer'] = []
+            # Keep half of the buffer for context (sliding window approach)
+            half_point = len(client['streaming_buffer']) // 2
+            client['streaming_buffer'] = client['streaming_buffer'][half_point:]
 
     except Exception as e:
         import traceback
@@ -497,6 +625,62 @@ def handle_stop_streaming(data):
         'status': 'stopped'
     })
 
+def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size_ms=500):
+    """Stream audio to client in chunks to simulate real-time generation"""
+    import math  # local imports so the helper is self-contained even if these
+    import time  # modules aren't already imported at the top of server.py
+    try:
+        if client_id not in active_clients:
+            print(f"Client {client_id} not found for streaming")
+            return
+
+        # Calculate chunk size in samples
+        chunk_size = int(generator.sample_rate * chunk_size_ms / 1000)
+        total_chunks = math.ceil(audio_tensor.size(0) / chunk_size)
+
+        print(f"Streaming audio in {total_chunks} chunks of {chunk_size_ms}ms each")
+
+        # Send initial response with text but no audio yet
+        socketio.emit('audio_response_start', {
+            'type': 'audio_response_start',
+            'text': text,
+            'total_chunks': total_chunks
+        }, room=client_id)
+
+        # Stream each chunk
+        for i in range(total_chunks):
+            start_idx = i * chunk_size
+            end_idx = min(start_idx + chunk_size, audio_tensor.size(0))
+
+            # Extract chunk
+            chunk = audio_tensor[start_idx:end_idx]
+
+            # Encode chunk
+            chunk_base64 = encode_audio_data(chunk)
+
+            # Send chunk
+            socketio.emit('audio_response_chunk', {
+                'type': 'audio_response_chunk',
+                'chunk_index': i,
+                'total_chunks': total_chunks,
+                'audio': chunk_base64,
+                'is_last': i == total_chunks - 1
+            }, room=client_id)
+
+            # Brief pause between chunks to simulate streaming
+            time.sleep(0.1)
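+            # (0.1 s between 500 ms chunks ships audio ~5x faster than real
+            # time, letting the client buffer ahead; sleeping
+            # chunk_size_ms / 1000 instead would pace at exactly real time)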
+
+        # Send completion message
+        socketio.emit('audio_response_complete', {
+            'type': 'audio_response_complete',
+            'text': text
+        }, room=client_id)
+
+        print(f"Audio streaming complete: {total_chunks} chunks sent")
+
+    except Exception as e:
+        print(f"Error streaming audio to client: {str(e)}")
+        import traceback
+        traceback.print_exc()
+
 if __name__ == "__main__":
     print(f"\n{'='*60}")
     print(f"🔊 Sesame AI Voice Chat Server (Flask Implementation)")
diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js
index c85da8a..b224b27 100644
--- a/Backend/voice-chat.js
+++ b/Backend/voice-chat.js
@@ -466,37 +466,27 @@ function sendAudioChunk(audioData, speaker) {
         return;
     }
 
-    console.log(`Creating WAV from audio data: length=${audioData.length}`);
+    console.log(`Preparing audio chunk: length=${audioData.length}, speaker=${speaker}`);
 
     // Check for NaN or invalid values
-    let hasNaN = false;
-    let min = Infinity;
-    let max = -Infinity;
-    let sum = 0;
-
+    let hasInvalidValues = false;
     for (let i = 0; i < audioData.length; i++) {
         if (isNaN(audioData[i]) || !isFinite(audioData[i])) {
-            hasNaN = true;
+            hasInvalidValues = true;
             console.warn(`Invalid audio value at index ${i}: ${audioData[i]}`);
             break;
         }
-        min = Math.min(min, audioData[i]);
-        max = Math.max(max, audioData[i]);
-        sum += audioData[i];
     }
 
-    if (hasNaN) {
-        console.warn('Audio data contains NaN or Infinity values. Creating silent audio instead.');
+    if (hasInvalidValues) {
+        console.warn('Audio data contains invalid values. Creating silent audio.');
         audioData = new Float32Array(audioData.length).fill(0);
-    } else {
-        const avg = sum / audioData.length;
-        console.log(`Audio stats: min=${min.toFixed(4)}, max=${max.toFixed(4)}, avg=${avg.toFixed(4)}`);
     }
 
     try {
-        // Create WAV blob with proper format
+        // Create WAV blob
         const wavData = createWavBlob(audioData, 24000);
-        console.log(`WAV blob created: size=${wavData.size} bytes, type=${wavData.type}`);
+        console.log(`WAV blob created: ${wavData.size} bytes`);
 
         const reader = new FileReader();
 
@@ -504,28 +494,21 @@ function sendAudioChunk(audioData, speaker) {
         try {
             // Get base64 data
             const base64data = reader.result;
-            console.log(`Base64 data created: length=${base64data.length}`);
+            console.log(`Base64 data created: ${base64data.length} chars`);
 
-            // Validate the base64 data before sending
-            if (!base64data || base64data.length < 100) {
-                console.warn('Generated base64 data is too small or invalid');
-                return;
-            }
-
-            // Send the audio chunk to the server
-            console.log('Sending audio data to server...');
+            // Send to server
             state.socket.emit('stream_audio', {
                 audio: base64data,
                 speaker: speaker
             });
-            console.log('Audio data sent successfully');
+            console.log('Audio chunk sent to server');
         } catch (err) {
             console.error('Error preparing audio data:', err);
         }
     };
 
-    reader.onerror = function(err) {
-        console.error('Error reading audio data:', err);
+    reader.onerror = function() {
+        console.error('Error reading audio data as base64');
     };
 
     reader.readAsDataURL(wavData);
@@ -534,19 +517,20 @@ function sendAudioChunk(audioData, speaker) {
     }
 }
 
-// Create WAV blob from audio data with validation
+// Create WAV blob from audio data with improved error handling
 function createWavBlob(audioData, sampleRate) {
-    // Check if audio data is valid
+    // Validate input
     if (!audioData || audioData.length === 0) {
-        console.warn('Empty audio data received');
-        // Return a tiny silent audio snippet instead
-        audioData = new Float32Array(100).fill(0);
+        console.warn('Empty audio data provided to createWavBlob');
+        audioData = new Float32Array(1024).fill(0); // Create 1024 samples of silence
     }
 
     // Function to convert Float32Array to Int16Array for WAV format
     function floatTo16BitPCM(output, offset, input) {
         for (let i = 0; i < input.length; i++, offset += 2) {
+            // Ensure values are in -1 to 1 range
             const s = Math.max(-1, Math.min(1, input[i]));
+            // Convert to 16-bit PCM
             output.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
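+            // (The scale factor is asymmetric because 16-bit PCM spans
+            // -32768..32767: negatives use 0x8000, positives 0x7FFF)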
         }
     }
 
@@ -558,40 +542,80 @@ function createWavBlob(audioData, sampleRate) {
     function writeString(view, offset, string) {
         for (let i = 0; i < string.length; i++) {
             view.setUint8(offset + i, string.charCodeAt(i));
         }
     }
 
-    // Create WAV file with header
-    function encodeWAV(samples) {
-        const buffer = new ArrayBuffer(44 + samples.length * 2);
+    try {
+        // Create WAV file with header - careful with buffer sizes
+        const buffer = new ArrayBuffer(44 + audioData.length * 2);
         const view = new DataView(buffer);
 
-        // RIFF chunk descriptor
+        // RIFF identifier
         writeString(view, 0, 'RIFF');
-        view.setUint32(4, 36 + samples.length * 2, true);
+
+        // RIFF chunk size: total file length minus the 8-byte RIFF header
+        view.setUint32(4, 36 + audioData.length * 2, true);
+
+        // WAVE identifier
         writeString(view, 8, 'WAVE');
 
-        // fmt sub-chunk
+        // fmt chunk identifier
         writeString(view, 12, 'fmt ');
+
+        // fmt chunk length
         view.setUint32(16, 16, true);
-        view.setUint16(20, 1, true); // PCM format
-        view.setUint16(22, 1, true); // Mono channel
+
+        // Sample format (1 is PCM)
+        view.setUint16(20, 1, true);
+
+        // Mono channel
+        view.setUint16(22, 1, true);
+
+        // Sample rate
         view.setUint32(24, sampleRate, true);
-        view.setUint32(28, sampleRate * 2, true); // Byte rate
-        view.setUint16(32, 2, true); // Block align
-        view.setUint16(34, 16, true); // Bits per sample
 
-        // data sub-chunk
+        // Byte rate (sample rate * block align)
+        view.setUint32(28, sampleRate * 2, true);
+
+        // Block align (channels * bytes per sample)
+        view.setUint16(32, 2, true);
+
+        // Bits per sample
+        view.setUint16(34, 16, true);
+
+        // data chunk identifier
         writeString(view, 36, 'data');
-        view.setUint32(40, samples.length * 2, true);
-        floatTo16BitPCM(view, 44, samples);
 
-        return buffer;
+        // data chunk length
+        view.setUint32(40, audioData.length * 2, true);
+
+        // Write the PCM samples
+        floatTo16BitPCM(view, 44, audioData);
+
+        // Create and return blob
+        return new Blob([view], { type: 'audio/wav' });
+    } catch (err) {
+        console.error('Error in createWavBlob:', err);
+
+        // Create a minimal valid WAV file with silence as fallback
+        const fallbackSamples = new Float32Array(1024).fill(0);
+        const fallbackBuffer = new ArrayBuffer(44 + fallbackSamples.length * 2);
+        const fallbackView = new DataView(fallbackBuffer);
+
+        writeString(fallbackView, 0, 'RIFF');
+        fallbackView.setUint32(4, 36 + fallbackSamples.length * 2, true);
+        writeString(fallbackView, 8, 'WAVE');
+        writeString(fallbackView, 12, 'fmt ');
+        fallbackView.setUint32(16, 16, true);
+        fallbackView.setUint16(20, 1, true);
+        fallbackView.setUint16(22, 1, true);
+        fallbackView.setUint32(24, sampleRate, true);
+        fallbackView.setUint32(28, sampleRate * 2, true);
+        fallbackView.setUint16(32, 2, true);
+        fallbackView.setUint16(34, 16, true);
+        writeString(fallbackView, 36, 'data');
+        fallbackView.setUint32(40, fallbackSamples.length * 2, true);
+        floatTo16BitPCM(fallbackView, 44, fallbackSamples);
+
+        return new Blob([fallbackView], { type: 'audio/wav' });
     }
-
-    // Convert audio data to TypedArray if it's a regular Array
-    const samples = Array.isArray(audioData) ? new Float32Array(audioData) : audioData;
-
-    // Create WAV blob
-    const wavBuffer = encodeWAV(samples);
-    return new Blob([wavBuffer], { type: 'audio/wav' });
 }
 
 // Draw audio visualizer