diff --git a/Backend/server.py b/Backend/server.py
index 5fbe12a..52a85e9 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -61,22 +61,42 @@ def load_models():
     global models
 
     logger.info("Loading CSM 1B model...")
-    models.generator = load_csm_1b(device=DEVICE)
+    try:
+        models.generator = load_csm_1b(device=DEVICE)
+        logger.info("CSM 1B model loaded successfully")
+        socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
+    except Exception as e:
+        logger.error(f"Error loading CSM 1B model: {str(e)}")
+        socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
 
     logger.info("Loading ASR pipeline...")
-    models.asr = pipeline(
-        "automatic-speech-recognition",
-        model="openai/whisper-small",
-        device=DEVICE
-    )
+    try:
+        models.asr = pipeline(
+            "automatic-speech-recognition",
+            model="openai/whisper-small",
+            device=DEVICE,
+            language="en",  # Force English language
+            return_attention_mask=True  # Add attention mask
+        )
+        logger.info("ASR pipeline loaded successfully")
+        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
+    except Exception as e:
+        logger.error(f"Error loading ASR pipeline: {str(e)}")
+        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
 
     logger.info("Loading Llama 3.2 model...")
-    models.llm = AutoModelForCausalLM.from_pretrained(
-        "meta-llama/Llama-3.2-1B",
-        device_map=DEVICE,
-        torch_dtype=torch.bfloat16
-    )
-    models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+    try:
+        models.llm = AutoModelForCausalLM.from_pretrained(
+            "meta-llama/Llama-3.2-1B",
+            device_map=DEVICE,
+            torch_dtype=torch.bfloat16
+        )
+        models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+        logger.info("Llama 3.2 model loaded successfully")
+        socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
+    except Exception as e:
+        logger.error(f"Error loading Llama 3.2 model: {str(e)}")
+        socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})
 
 # Load models in a background thread
 threading.Thread(target=load_models, daemon=True).start()
@@ -118,6 +138,20 @@ def health_check():
         "models_loaded": models.generator is not None and models.llm is not None
     })
 
+# Add a system status endpoint
+@app.route('/api/status')
+def system_status():
+    return jsonify({
+        "status": "ok",
+        "cuda_available": torch.cuda.is_available(),
+        "device": DEVICE,
+        "models": {
+            "generator": models.generator is not None,
+            "asr": models.asr is not None,
+            "llm": models.llm is not None
+        }
+    })
+
 # Socket event handlers
 @socketio.on('connect')
 def handle_connect(auth=None):
@@ -225,10 +259,12 @@ def process_audio_queue(session_id, q):
 def process_audio_and_respond(session_id, data):
     """Process audio data and generate a response"""
     if models.generator is None or models.asr is None or models.llm is None:
+        logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
         return
 
+    logger.info(f"Processing audio for session {session_id}")
    conversation = active_conversations[session_id]
 
     try:
@@ -238,9 +274,15 @@
         # Process base64 audio data
         audio_data = data['audio']
         speaker_id = data['speaker']
+        logger.info(f"Received audio from speaker {speaker_id}")
 
         # Convert from base64 to WAV
-        audio_bytes = base64.b64decode(audio_data.split(',')[1])
+        try:
+            audio_bytes = base64.b64decode(audio_data.split(',')[1])
+            logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
+        except Exception as e:
+            logger.error(f"Error decoding base64 audio: {str(e)}")
+            raise
 
         # Save to temporary file for processing
         with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
@@ -308,11 +350,19 @@
         prompt = f"{conversation_history}Assistant: "
 
         # Generate response with Llama
-        input_ids = models.tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
-
+        input_tokens = models.tokenizer(
+            prompt,
+            return_tensors="pt",
+            padding=True,
+            return_attention_mask=True
+        )
+        input_ids = input_tokens.input_ids.to(DEVICE)
+        attention_mask = input_tokens.attention_mask.to(DEVICE)
+
         with torch.no_grad():
             generated_ids = models.llm.generate(
                 input_ids,
+                attention_mask=attention_mask,  # Add the attention mask
                 max_new_tokens=100,
                 temperature=0.7,
                 top_p=0.9,
@@ -437,5 +487,6 @@ cleanup_thread.start()
 # Start the server
 if __name__ == '__main__':
     port = int(os.environ.get('PORT', 5000))
-    logger.info(f"Starting server on port {port}")
-    socketio.run(app, host='0.0.0.0', port=port, debug=False, allow_unsafe_werkzeug=True)
\ No newline at end of file
+    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
+    logger.info(f"Starting server on port {port} (debug={debug_mode})")
+    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
\ No newline at end of file
diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js
index 109f426..5c3f247 100644
--- a/Backend/voice-chat.js
+++ b/Backend/voice-chat.js
@@ -378,25 +378,39 @@ function handleSpeechState(isSilent) {
         if (state.isSpeaking) {
             state.isSpeaking = false;
 
-            // Get the current audio data and send it
-            const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max
-            state.analyser.getFloatTimeDomainData(audioBuffer);
-
-            // Create WAV blob
-            const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate);
-
-            // Convert to base64
-            const reader = new FileReader();
-            reader.onloadend = function() {
-                sendAudioChunk(reader.result, state.currentSpeaker);
-            };
-            reader.readAsDataURL(wavBlob);
-
-            // Update button state
-            elements.streamButton.classList.add('processing');
-            elements.streamButton.innerHTML = ' Processing...';
-
-            addSystemMessage('Processing your message...');
+            try {
+                // Get the current audio data and send it
+                const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max
+                state.analyser.getFloatTimeDomainData(audioBuffer);
+
+                // Check if audio has content
+                const hasAudioContent = audioBuffer.some(sample => Math.abs(sample) > 0.01);
+
+                if (!hasAudioContent) {
+                    console.warn('Audio buffer appears to be empty or very quiet');
+                    addSystemMessage('No speech detected. Please try again and speak clearly.');
+                    return;
+                }
+
+                // Create WAV blob
+                const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate);
+
+                // Convert to base64
+                const reader = new FileReader();
+                reader.onloadend = function() {
+                    sendAudioChunk(reader.result, state.currentSpeaker);
+                };
+                reader.readAsDataURL(wavBlob);
+
+                // Update button state
+                elements.streamButton.classList.add('processing');
+                elements.streamButton.innerHTML = ' Processing...';
+
+                addSystemMessage('Processing your message...');
+            } catch (e) {
+                console.error('Error recording audio:', e);
+                addSystemMessage('Error recording audio. Please try again.');
+            }
         }
     }, CLIENT_SILENCE_DURATION_MS);
 }