diff --git a/Backend/server.py b/Backend/server.py
index 978b87c..352f5cd 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -8,9 +8,17 @@ import numpy as np
 from flask import Flask, render_template, request
 from flask_socketio import SocketIO, emit
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from faster_whisper import WhisperModel
-from generator import load_csm_1b, Segment
 from collections import deque
+import requests
+import huggingface_hub
+from generator import load_csm_1b, Segment
+
+# Configure environment with longer timeouts
+os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600"  # 10 minutes timeout for downloads
+requests.adapters.DEFAULT_TIMEOUT = 60  # Increase default requests timeout
+
+# Create a models directory for caching
+os.makedirs("models", exist_ok=True)
 
 app = Flask(__name__)
 app.config['SECRET_KEY'] = 'your-secret-key'
@@ -29,23 +37,50 @@ else:
 
 print(f"Using device: {device}")
 
-# Initialize Faster-Whisper for transcription
-print("Loading Whisper model...")
-whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type)
+# Initialize models with proper error handling
+whisper_model = None
+csm_generator = None
+llm_model = None
+llm_tokenizer = None
 
-# Initialize CSM model for audio generation
-print("Loading CSM model...")
-csm_generator = load_csm_1b(device=device)
-
-# Initialize Llama 3.2 model for response generation
-print("Loading Llama 3.2 model...")
-llm_model_id = "meta-llama/Llama-3.2-1B"  # Choose appropriate size based on resources
-llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
-llm_model = AutoModelForCausalLM.from_pretrained(
-    llm_model_id,
-    torch_dtype=torch.bfloat16,
-    device_map=device
-)
+def load_models():
+    global whisper_model, csm_generator, llm_model, llm_tokenizer
+
+    # Initialize Faster-Whisper for transcription
+    try:
+        print("Loading Whisper model...")
+        # Import here to avoid immediate import errors if package is missing
+        from faster_whisper import WhisperModel
+        whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper")
+        print("Whisper model loaded successfully")
+    except Exception as e:
+        print(f"Error loading Whisper model: {e}")
+        print("Will use backup speech recognition method if available")
+
+    # Initialize CSM model for audio generation
+    try:
+        print("Loading CSM model...")
+        csm_generator = load_csm_1b(device=device)
+        print("CSM model loaded successfully")
+    except Exception as e:
+        print(f"Error loading CSM model: {e}")
+        print("Audio generation will not be available")
+
+    # Initialize Llama 3.2 model for response generation
+    try:
+        print("Loading Llama 3.2 model...")
+        llm_model_id = "meta-llama/Llama-3.2-1B"  # Choose appropriate size based on resources
+        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama")
+        llm_model = AutoModelForCausalLM.from_pretrained(
+            llm_model_id,
+            torch_dtype=torch.bfloat16,
+            device_map=device,
+            cache_dir="./models/llama"
+        )
+        print("Llama 3.2 model loaded successfully")
+    except Exception as e:
+        print(f"Error loading Llama 3.2 model: {e}")
+        print("Will use a fallback response generation method")
 
 # Store conversation context
 conversation_context = {}  # session_id -> context
@@ -128,7 +163,7 @@ def process_user_utterance(session_id):
     context['is_speaking'] = False
     context['silence_start'] = None
 
-    # Save audio to temporary WAV file for Whisper transcription
+    # Save audio to temporary WAV file for transcription
     temp_audio_path = f"temp_audio_{session_id}.wav"
f"temp_audio_{session_id}.wav" torchaudio.save( temp_audio_path, @@ -136,25 +171,17 @@ def process_user_utterance(session_id): 44100 # Assuming 44.1kHz from client ) - # Transcribe speech using Faster-Whisper try: - segments, info = whisper_model.transcribe(temp_audio_path, beam_size=5) + # Try using Whisper first if available + if whisper_model is not None: + user_text = transcribe_with_whisper(temp_audio_path) + else: + # Fallback to Google's speech recognition + user_text = transcribe_with_google(temp_audio_path) - # Collect all text from segments - user_text = "" - for segment in segments: - segment_text = segment.text.strip() - print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}") - user_text += segment_text + " " - - user_text = user_text.strip() - - # Cleanup temp file - if os.path.exists(temp_audio_path): - os.remove(temp_audio_path) - if not user_text: print("No speech detected.") + emit('error', {'message': 'No speech detected. Please try again.'}, room=session_id) return print(f"Transcribed: {user_text}") @@ -171,79 +198,158 @@ def process_user_utterance(session_id): bot_response = generate_llm_response(user_text, context['segments']) print(f"Bot response: {bot_response}") - # Convert to audio using CSM - bot_audio = generate_audio_response(bot_response, context['segments']) - - # Convert audio to base64 for sending over websocket - audio_bytes = io.BytesIO() - torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav") - audio_bytes.seek(0) - audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8') - - # Add bot response to conversation history - bot_segment = Segment( - text=bot_response, - speaker=1, # Bot is speaker 1 - audio=bot_audio - ) - context['segments'].append(bot_segment) - # Send transcribed text to client emit('transcription', {'text': user_text}, room=session_id) - # Send audio response to client - emit('audio_response', { - 'audio': audio_b64, - 'text': bot_response - }, room=session_id) + # Generate and send audio response if CSM is available + if csm_generator is not None: + # Convert to audio using CSM + bot_audio = generate_audio_response(bot_response, context['segments']) + + # Convert audio to base64 for sending over websocket + audio_bytes = io.BytesIO() + torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav") + audio_bytes.seek(0) + audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8') + + # Add bot response to conversation history + bot_segment = Segment( + text=bot_response, + speaker=1, # Bot is speaker 1 + audio=bot_audio + ) + context['segments'].append(bot_segment) + + # Send audio response to client + emit('audio_response', { + 'audio': audio_b64, + 'text': bot_response + }, room=session_id) + else: + # Send text-only response if audio generation isn't available + emit('text_response', {'text': bot_response}, room=session_id) + + # Add text-only bot response to conversation history + bot_segment = Segment( + text=bot_response, + speaker=1, # Bot is speaker 1 + audio=torch.zeros(1) # Placeholder empty audio + ) + context['segments'].append(bot_segment) except Exception as e: print(f"Error processing speech: {e}") emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id) - # Cleanup temp file in case of error + finally: + # Cleanup temp file if os.path.exists(temp_audio_path): os.remove(temp_audio_path) +def transcribe_with_whisper(audio_path): + """Transcribe audio using Faster-Whisper""" + segments, info = 
+
+    # Collect all text from segments
+    user_text = ""
+    for segment in segments:
+        segment_text = segment.text.strip()
+        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}")
+        user_text += segment_text + " "
+
+    return user_text.strip()
+
+def transcribe_with_google(audio_path):
+    """Fallback transcription using Google's speech recognition"""
+    import speech_recognition as sr
+    recognizer = sr.Recognizer()
+
+    with sr.AudioFile(audio_path) as source:
+        audio = recognizer.record(source)
+        try:
+            text = recognizer.recognize_google(audio)
+            return text
+        except sr.UnknownValueError:
+            return ""
+        except sr.RequestError:
+            # If Google API fails, try a basic energy-based VAD approach
+            # This is a very basic fallback and won't give good results
+            return "[Speech detected but transcription failed]"
+
 def generate_llm_response(user_text, conversation_segments):
-    """Generate text response using Llama 3.2"""
-    # Format conversation history for the LLM
-    conversation_history = ""
-    for segment in conversation_segments[-5:]:  # Use last 5 utterances for context
-        speaker_name = "User" if segment.speaker == 0 else "Assistant"
-        conversation_history += f"{speaker_name}: {segment.text}\n"
+    """Generate text response using available model"""
+    if llm_model is not None and llm_tokenizer is not None:
+        # Format conversation history for the LLM
+        conversation_history = ""
+        for segment in conversation_segments[-5:]:  # Use last 5 utterances for context
+            speaker_name = "User" if segment.speaker == 0 else "Assistant"
+            conversation_history += f"{speaker_name}: {segment.text}\n"
+
+        # Add the current user query
+        conversation_history += f"User: {user_text}\nAssistant:"
+
+        try:
+            # Generate response
+            inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device)
+            output = llm_model.generate(
+                inputs.input_ids,
+                max_new_tokens=150,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True
+            )
+
+            response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+            return response.strip()
+        except Exception as e:
+            print(f"Error generating response with LLM: {e}")
+            return fallback_response(user_text)
+    else:
+        return fallback_response(user_text)
+
+def fallback_response(user_text):
+    """Generate a simple fallback response when LLM is not available"""
+    # Simple rule-based responses
+    user_text_lower = user_text.lower()
 
-    # Add the current user query
-    conversation_history += f"User: {user_text}\nAssistant:"
+    if "hello" in user_text_lower or "hi" in user_text_lower:
+        return "Hello! I'm a simple fallback assistant. The main language model couldn't be loaded, so I have limited capabilities."
 
-    # Generate response
-    inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device)
-    output = llm_model.generate(
-        inputs.input_ids,
-        max_new_tokens=150,
-        temperature=0.7,
-        top_p=0.9,
-        do_sample=True
-    )
+    elif "how are you" in user_text_lower:
+        return "I'm functioning within my limited capabilities. How can I assist you today?"
 
-    response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-    return response.strip()
+    elif "thank" in user_text_lower:
+        return "You're welcome! Let me know if there's anything else I can help with."
+
+    elif "bye" in user_text_lower or "goodbye" in user_text_lower:
+        return "Goodbye! Have a great day!"
+
+    elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]):
+        return "I'm running in fallback mode and can't answer complex questions. Please try again when the main language model is available."
+
+    else:
+        return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available."
 
 def generate_audio_response(text, conversation_segments):
     """Generate audio response using CSM"""
-    # Use the last few conversation segments as context
-    context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
-
-    # Generate audio for bot response
-    audio = csm_generator.generate(
-        text=text,
-        speaker=1,  # Bot is speaker 1
-        context=context_segments,
-        max_audio_length_ms=10000,  # 10 seconds max
-        temperature=0.9,
-        topk=50
-    )
-
-    return audio
+    try:
+        # Use the last few conversation segments as context
+        context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
+
+        # Generate audio for bot response
+        audio = csm_generator.generate(
+            text=text,
+            speaker=1,  # Bot is speaker 1
+            context=context_segments,
+            max_audio_length_ms=10000,  # 10 seconds max
+            temperature=0.9,
+            topk=50
+        )
+
+        return audio
+    except Exception as e:
+        print(f"Error generating audio: {e}")
+        # Return silence as fallback
+        return torch.zeros(csm_generator.sample_rate * 3)  # 3 seconds of silence
 
 if __name__ == '__main__':
     # Ensure the existing index.html file is in the correct location
@@ -253,4 +359,11 @@ if __name__ == '__main__':
     if os.path.exists('index.html') and not os.path.exists('templates/index.html'):
         os.rename('index.html', 'templates/index.html')
 
+    # Load models before starting the server
+    print("Starting model loading...")
+    # In a production environment, you could load models in a separate thread
+    load_models()
+
+    # Start the server
+    print("Starting Flask SocketIO server...")
     socketio.run(app, host='0.0.0.0', port=5000, debug=False)
\ No newline at end of file