diff --git a/Backend/server.py b/Backend/server.py
index 8f4e278..cfcc6ea 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -16,6 +16,8 @@ from flask_socketio import SocketIO, emit, disconnect
 from generator import load_csm_1b, Segment
 from collections import deque
 from threading import Lock
+from transformers import pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 # Configure logging
 logging.basicConfig(
@@ -84,29 +86,76 @@ logger.info("Initialized simple speech detector")
 
 # Model Loading Functions
 def load_speech_models():
-    """Load speech generation model"""
-    # Load speech generation model (Sesame CSM)
-    try:
-        logger.info(f"Loading Sesame CSM model on {device}...")
-        generator = load_csm_1b(device=device)
-        logger.info("Sesame CSM model loaded successfully")
-    except Exception as e:
-        logger.error(f"Error loading Sesame CSM on {device}: {e}")
-        if device == "cuda":
-            try:
-                logger.info("Trying to load Sesame CSM on CPU instead...")
-                generator = load_csm_1b(device="cpu")
-                logger.info("Sesame CSM model loaded on CPU successfully")
-            except Exception as cpu_error:
-                logger.critical(f"Failed to load speech synthesis model: {cpu_error}")
-                raise RuntimeError("Failed to load speech synthesis model")
-        else:
-            raise RuntimeError("Failed to load speech synthesis model on any device")
+    """Load speech generation and recognition models"""
+    # Load the CSM speech generation model
+    generator = load_csm_1b(device=device)
 
-    return generator
+    # Load Whisper model for speech recognition
+    try:
+        logger.info(f"Loading speech recognition model on {device}...")
+        speech_recognizer = pipeline("automatic-speech-recognition",
+                                     model="openai/whisper-small",
+                                     device=device)
+        logger.info("Speech recognition model loaded successfully")
+    except Exception as e:
+        logger.error(f"Error loading speech recognition model: {e}")
+        speech_recognizer = None
+
+    return generator, speech_recognizer
 
-# Load speech model
-generator = load_speech_models()
+# Unpack both models
+generator, speech_recognizer = load_speech_models()
+
+# Initialize Llama 3.2 model for conversation responses
+def load_llm_model():
+    """Load Llama 3.2 model for generating text responses"""
+    try:
+        logger.info("Loading Llama 3.2 model for conversational responses...")
+        model_id = "meta-llama/Llama-3.2-1B-Instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # Determine compute device for the LLM
+        llm_device = "cpu"  # Default to CPU for the LLM
+
+        # Use CUDA if available and there's enough VRAM
+        if device == "cuda" and torch.cuda.is_available():
+            try:
+                free_mem = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)
+                # If we have at least 2 GB free, use CUDA for the LLM
+                if free_mem > 2 * 1024 * 1024 * 1024:
+                    llm_device = "cuda"
+            except Exception:
+                pass
+
+        logger.info(f"Using {llm_device} for Llama 3.2 model")
+
+        # Load the model with lower precision for efficiency
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16 if llm_device == "cuda" else torch.float32,
+            device_map=llm_device
+        )
+
+        # Create a pipeline for easier inference
+        llm = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_length=512,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            repetition_penalty=1.1
+        )
+
+        logger.info("Llama 3.2 model loaded successfully")
+        return llm
+    except Exception as e:
+        logger.error(f"Error loading Llama 3.2 model: {e}")
+        return None
+
+# Load the LLM model
+llm = load_llm_model()
 
 # Set up Flask and Socket.IO
 app = Flask(__name__)
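Note on the VRAM check in `load_llm_model()`: `torch.cuda.memory_allocated()` only counts memory allocated by the current PyTorch process, so the 2 GB headroom test can look optimistic when other processes share the GPU. A minimal alternative sketch, assuming a single-GPU host; `pick_llm_device` is an illustrative helper, not part of this patch:

```python
import torch

def pick_llm_device(min_free_bytes: int = 2 * 1024**3) -> str:
    """Choose a device for the LLM from driver-reported free VRAM."""
    if torch.cuda.is_available():
        # mem_get_info wraps cudaMemGetInfo, so unlike memory_allocated()
        # it also sees VRAM held by other processes on the device
        free, _total = torch.cuda.mem_get_info()
        if free > min_free_bytes:
            return "cuda"
    return "cpu"
```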
@@ -249,26 +298,116 @@ def encode_audio_data(audio_tensor: torch.Tensor) -> str:
     return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
 
 def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str:
-    """Process speech and return a simple response"""
-    # In this simplified version, we'll just check if there's sound
-    # and provide basic responses instead of doing actual speech recognition
+    """Process speech with speech recognition"""
+    if not speech_recognizer:
+        # Fall back to basic energy detection if the Whisper model failed to load
+        return detect_speech_energy(audio_tensor)
 
-    if speech_detector and speech_detector.detect_speech(audio_tensor, generator.sample_rate):
-        # Generate a response based on audio energy
-        energy = torch.mean(torch.abs(audio_tensor)).item()
+    # Save audio to a temp file for Whisper
+    temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav")
+    try:
+        torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
 
-        if energy > 0.1:  # Louder speech
-            return "I heard you speaking clearly. How can I help you today?"
-        elif energy > 0.05:  # Moderate speech
-            return "I heard you say something. Could you please repeat that?"
-        else:  # Soft speech
-            return "I detected some speech, but it was quite soft. Could you speak up a bit?"
-    else:
+        # Perform speech recognition
+        result = speech_recognizer(temp_path)
+        transcription = result["text"]
+
+        # Ask the user to try again if nothing was transcribed
+        if not transcription or transcription.isspace():
+            return "I didn't detect any speech. Could you please try again?"
+
+        return transcription
+
+    except Exception as e:
+        logger.error(f"Speech recognition error: {e}")
+        return "Sorry, I couldn't understand what you said. Could you try again?"
+    finally:
+        # Always remove the temp file, even when recognition fails
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+
+def detect_speech_energy(audio_tensor: torch.Tensor) -> str:
+    """Basic speech detection based on audio energy levels"""
+    # Calculate audio energy
+    energy = torch.mean(torch.abs(audio_tensor)).item()
+
+    logger.debug(f"Audio energy detected: {energy:.6f}")
+
+    # Generate response based on energy level
+    if energy > 0.1:  # Louder speech
+        return "I heard you speaking clearly. How can I help you today?"
+    elif energy > 0.05:  # Moderate speech
+        return "I heard you say something. Could you please repeat that?"
+    elif energy > 0.02:  # Soft speech
+        return "I detected some speech, but it was quite soft. Could you speak up a bit?"
+    else:  # Very soft or no speech
         return "I didn't detect any speech. Could you please try again?"
 
 def generate_response(text: str, conversation_history: List[Segment]) -> str:
-    """Generate a contextual response based on the transcribed text"""
-    # Simple response logic - can be replaced with a more sophisticated LLM
+    """Generate a contextual response to the transcribed text using Llama 3.2"""
+    # If the LLM is not available, fall back to simple rule-based responses
+    if llm is None:
+        return generate_simple_response(text)
+
+    try:
+        # Build a conversational prompt: recent history plus the current user input
+        history_str = ""
+
+        # Keep up to the last 8 non-empty segments (about 4 user/assistant exchanges)
+        recent_segments = [
+            seg for seg in conversation_history[-8:]
+            if seg.text and not seg.text.isspace()
+        ]
+
+        for segment in recent_segments:
+            speaker_name = "User" if segment.speaker == 0 else "Assistant"
+            history_str += f"{speaker_name}: {segment.text}\n"
+
+        # Construct the prompt for Llama 3.2 (plain-text roles; the reply is
+        # recovered below by splitting on "Assistant:")
+        prompt = f"""System: You are Sesame, a helpful, friendly and concise voice assistant.
+Keep your responses conversational, natural, and to the point.
+Respond to the user's latest message in the context of the conversation.
+
+{history_str}User: {text}
+Assistant:"""
+
+        logger.debug(f"LLM Prompt: {prompt}")
+
+        # Generate a response with the LLM
+        result = llm(
+            prompt,
+            max_new_tokens=150,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            repetition_penalty=1.1
+        )
+
+        # The pipeline returns the prompt plus the completion
+        response = result[0]["generated_text"]
+
+        # Extract just the Assistant's reply (the text after the final "Assistant:")
+        response = response.split("Assistant:")[-1].strip()
+
+        # Keep only the first line and cap the length so it stays TTS-friendly
+        response = response.split("\n")[0].strip()
+        if len(response) > 200:
+            response = response[:197] + "..."
+
+        logger.info(f"LLM response: {response}")
+        return response
+
+    except Exception as e:
+        logger.error(f"Error generating LLM response: {e}")
+        # Fall back to simple responses
+        return generate_simple_response(text)
+
+def generate_simple_response(text: str) -> str:
+    """Generate a simple rule-based response as a fallback"""
     responses = {
         "hello": "Hello there! How can I help you today?",
         "hi": "Hi there! What can I do for you?",
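The temp-file round trip in `process_speech()` is the simplest way to call Whisper, but the Hugging Face ASR pipeline also accepts in-memory audio as a dict with `raw` and `sampling_rate` keys, which avoids disk I/O on every utterance. A sketch of that variant, assuming `speech_recognizer` loaded as above; `transcribe_in_memory` is an illustrative helper, not part of this patch:

```python
import numpy as np

def transcribe_in_memory(audio_tensor, sample_rate: int) -> str:
    """Run Whisper on an in-memory waveform instead of a temp WAV file."""
    # Whisper expects mono float32 audio; passing sampling_rate lets the
    # pipeline resample to the model's expected 16 kHz internally
    audio = audio_tensor.squeeze().cpu().numpy().astype(np.float32)
    result = speech_recognizer({"raw": audio, "sampling_rate": sample_rate})
    return result["text"]
```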
@@ -295,7 +434,7 @@ def generate_response(text: str, conversation_history: List[Segment]) -> str:
     elif len(text) < 10:
         return "Thanks for your message. Could you elaborate a bit more?"
     else:
-        return f"I heard you speaking. That's interesting! Can you tell me more about that?"
+        return "I heard you say something about that. Can you tell me more?"
 
 # Flask Routes
 @app.route('/')
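The plain `System:`/`User:`/`Assistant:` prompt keeps the extraction logic in `generate_response()` simple, but Llama 3.2 Instruct ships a chat template in its tokenizer that formats turns with the special tokens the model was trained on. A sketch of building the prompt that way, assuming the `llm` pipeline loaded successfully; `build_chat_prompt` is an illustrative helper, and the reply would then be recovered by slicing off the prompt rather than splitting on `"Assistant:"`:

```python
def build_chat_prompt(recent_segments, text: str) -> str:
    """Format the conversation with the model's own chat template."""
    messages = [{
        "role": "system",
        "content": "You are Sesame, a helpful, friendly and concise voice assistant."
    }]
    for segment in recent_segments:
        role = "user" if segment.speaker == 0 else "assistant"
        messages.append({"role": role, "content": segment.text})
    messages.append({"role": "user", "content": text})
    # add_generation_prompt appends the assistant header so the model
    # continues as the assistant instead of predicting another user turn
    return llm.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
```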