diff --git a/Backend/index.html b/Backend/index.html
index 359ed41..6f2a4fb 100644
--- a/Backend/index.html
+++ b/Backend/index.html
@@ -3,454 +3,266 @@
- Voice Assistant - CSM & Whisper
-
+ AI Voice Chat
-

Voice Assistant with CSM & Whisper

-
- -
- -
- - - -
Connecting to server...
+
+
+

AI Voice Assistant

+
- +
+
+
+ Disconnected +
+
+ +
+
+ +
+
Your conversation will appear here.
+
+ + + +
+ + +
+ + + +
+

Status

+
+
+
Whisper Model: Loading...
+
+
+
CSM Audio Model: Loading...
+
+
+
LLM Model: Loading...
+
+
+
WebRTC: Not Connected
+
+
+
+
+ + + + + \ No newline at end of file diff --git a/Backend/server.py b/Backend/server.py index 93fac92..af76560 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -8,15 +8,14 @@ import numpy as np from flask import Flask, render_template, request from flask_socketio import SocketIO, emit from transformers import AutoModelForCausalLM, AutoTokenizer -from collections import deque +import threading +import queue import requests import huggingface_hub from generator import load_csm_1b, Segment - -# Force CPU mode regardless of what's available -# This bypasses the CUDA/cuDNN library requirements -os.environ["CUDA_VISIBLE_DEVICES"] = "" # Hide all CUDA devices -torch.backends.cudnn.enabled = False # Disable cuDNN +from collections import deque +import json +import webrtcvad # For voice activity detection # Configure environment with longer timeouts os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600" # 10 minutes timeout for downloads @@ -27,28 +26,92 @@ os.makedirs("models", exist_ok=True) app = Flask(__name__) app.config['SECRET_KEY'] = 'your-secret-key' -socketio = SocketIO(app, cors_allowed_origins="*") +socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet') -# Force CPU regardless of what hardware is available -device = "cuda" if torch.cuda.is_available() else "cpu" -whisper_compute_type = "int8" -print(f"Forcing CPU mode for all models") +# Explicitly check for CUDA and print detailed info +print("\n=== CUDA Information ===") +if torch.cuda.is_available(): + print(f"CUDA is available") + print(f"CUDA version: {torch.version.cuda}") + print(f"Number of GPUs: {torch.cuda.device_count()}") + for i in range(torch.cuda.device_count()): + print(f"GPU {i}: {torch.cuda.get_device_name(i)}") +else: + print("CUDA is not available") + +# Check for cuDNN +try: + import ctypes + ctypes.CDLL("libcudnn_ops_infer.so.8") + print("cuDNN is available") +except: + print("cuDNN is not available (libcudnn_ops_infer.so.8 not found)") + +# Determine compute device +try: + if torch.cuda.is_available(): + device = "cuda" + whisper_compute_type = "float16" + print("🟢 CUDA is available and initialized successfully") + elif torch.backends.mps.is_available(): + device = "mps" + whisper_compute_type = "float32" + print("🟢 MPS is available (Apple Silicon)") + else: + device = "cpu" + whisper_compute_type = "int8" + print("🟡 Using CPU (CUDA/MPS not available)") +except Exception as e: + print(f"🔴 Error initializing CUDA: {e}") + print("🔴 Falling back to CPU") + device = "cpu" + whisper_compute_type = "int8" + +print(f"Using device: {device}") # Initialize models with proper error handling whisper_model = None csm_generator = None llm_model = None llm_tokenizer = None +vad = None + +# Constants +SAMPLE_RATE = 16000 # For VAD +VAD_FRAME_SIZE = 480 # 30ms at 16kHz for VAD +VAD_MODE = 3 # Aggressive mode for better results +AUDIO_CHUNK_SIZE = 2400 # 100ms chunks when streaming AI voice + +# Audio sample rates +CLIENT_SAMPLE_RATE = 44100 # Browser WebAudio default +WHISPER_SAMPLE_RATE = 16000 # Whisper expects 16kHz + +# Session data structures +user_sessions = {} # session_id -> complete session data + +# WebRTC ICE servers (STUN/TURN servers for NAT traversal) +ICE_SERVERS = [ + {"urls": "stun:stun.l.google.com:19302"}, + {"urls": "stun:stun1.l.google.com:19302"} +] def load_models(): - global whisper_model, csm_generator, llm_model, llm_tokenizer + """Load all necessary models""" + global whisper_model, csm_generator, llm_model, llm_tokenizer, vad + + # Initialize Voice Activity Detector + try: + vad = 
webrtcvad.Vad(VAD_MODE) + print("Voice Activity Detector initialized") + except Exception as e: + print(f"Error initializing VAD: {e}") + vad = None # Initialize Faster-Whisper for transcription try: - print("Loading Whisper model on CPU...") - # Import here to avoid immediate import errors if package is missing + print("Loading Whisper model...") from faster_whisper import WhisperModel - whisper_model = WhisperModel("tiny", device="cpu", compute_type="int8", download_root="./models/whisper") + whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper") print("Whisper model loaded successfully") except Exception as e: print(f"Error loading Whisper model: {e}") @@ -56,8 +119,8 @@ def load_models(): # Initialize CSM model for audio generation try: - print("Loading CSM model on CPU...") - csm_generator = load_csm_1b(device="cpu") + print("Loading CSM model...") + csm_generator = load_csm_1b(device=device) print("CSM model loaded successfully") except Exception as e: print(f"Error loading CSM model: {e}") @@ -65,13 +128,14 @@ def load_models(): # Initialize Llama 3.2 model for response generation try: - print("Loading Llama 3.2 model on CPU...") - llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources + print("Loading Llama 3.2 model...") + llm_model_id = "meta-llama/Llama-3.2-1B" llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama") + dtype = torch.bfloat16 if device != "cpu" else torch.float32 llm_model = AutoModelForCausalLM.from_pretrained( llm_model_id, - torch_dtype=torch.float32, # Use float32 on CPU - device_map="cpu", + torch_dtype=dtype, + device_map=device, cache_dir="./models/llama", low_cpu_mem_usage=True ) @@ -80,168 +144,344 @@ def load_models(): print(f"Error loading Llama 3.2 model: {e}") print("Will use a fallback response generation method") -# Store conversation context -conversation_context = {} # session_id -> context - @app.route('/') def index(): + """Serve the main interface""" return render_template('index.html') +@app.route('/voice-chat.js') +def voice_chat_js(): + """Serve the JavaScript for voice chat""" + return app.send_static_file('voice-chat.js') + @socketio.on('connect') def handle_connect(): - print(f"Client connected: {request.sid}") - conversation_context[request.sid] = { + """Handle new client connection""" + session_id = request.sid + print(f"Client connected: {session_id}") + + # Initialize session data + user_sessions[session_id] = { + # Conversation context 'segments': [], - 'speakers': [0, 1], # 0 = user, 1 = bot - 'audio_buffer': deque(maxlen=10), # Store recent audio chunks - 'is_speaking': False, - 'silence_start': None + 'conversation_history': [], + 'is_turn_active': False, + + # Audio buffers and state + 'vad_buffer': deque(maxlen=30), # ~1s of audio at 30fps + 'audio_buffer': bytearray(), + 'is_user_speaking': False, + 'last_vad_active': time.time(), + 'silence_duration': 0, + 'speech_frames': 0, + + # AI state + 'is_ai_speaking': False, + 'should_interrupt_ai': False, + 'ai_stream_queue': queue.Queue(), + + # WebRTC status + 'webrtc_connected': False, + 'webrtc_peer_id': None, + + # Processing flags + 'is_processing': False, + 'pending_user_audio': None } - emit('ready', {'message': 'Connection established'}) + + # Send config to client + emit('session_ready', { + 'whisper_available': whisper_model is not None, + 'csm_available': csm_generator is not None, + 'llm_available': llm_model is not None, + 
'client_sample_rate': CLIENT_SAMPLE_RATE, + 'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000, + 'ice_servers': ICE_SERVERS + }) @socketio.on('disconnect') def handle_disconnect(): - print(f"Client disconnected: {request.sid}") - if request.sid in conversation_context: - del conversation_context[request.sid] + """Handle client disconnection""" + session_id = request.sid + print(f"Client disconnected: {session_id}") + + # Clean up resources + if session_id in user_sessions: + # Signal any running threads to stop + user_sessions[session_id]['should_interrupt_ai'] = True + + # Clean up resources + del user_sessions[session_id] -@socketio.on('start_speaking') -def handle_start_speaking(): - if request.sid in conversation_context: - conversation_context[request.sid]['is_speaking'] = True - conversation_context[request.sid]['audio_buffer'].clear() - print(f"User {request.sid} started speaking") - -@socketio.on('audio_chunk') -def handle_audio_chunk(data): - if request.sid not in conversation_context: +@socketio.on('webrtc_signal') +def handle_webrtc_signal(data): + """Handle WebRTC signaling for P2P connection establishment""" + session_id = request.sid + if session_id not in user_sessions: return - context = conversation_context[request.sid] - - # Decode audio data - audio_data = base64.b64decode(data['audio']) - audio_numpy = np.frombuffer(audio_data, dtype=np.float32) - audio_tensor = torch.tensor(audio_numpy) - - # Add to buffer - context['audio_buffer'].append(audio_tensor) - - # Check for silence to detect end of speech - if context['is_speaking'] and is_silence(audio_tensor): - if context['silence_start'] is None: - context['silence_start'] = time.time() - elif time.time() - context['silence_start'] > 1.0: # 1 second of silence - # Process the complete utterance - process_user_utterance(request.sid) - else: - context['silence_start'] = None + # Simply relay the signal to the client + # In a multi-user app, we would route this to the correct peer + emit('webrtc_signal', data) -@socketio.on('stop_speaking') -def handle_stop_speaking(): - if request.sid in conversation_context: - conversation_context[request.sid]['is_speaking'] = False - process_user_utterance(request.sid) - print(f"User {request.sid} stopped speaking") - -def is_silence(audio_tensor, threshold=0.02): - """Check if an audio chunk is silence based on amplitude threshold""" - return torch.mean(torch.abs(audio_tensor)) < threshold - -def process_user_utterance(session_id): - """Process completed user utterance, generate response and send audio back""" - context = conversation_context[session_id] - - if not context['audio_buffer']: +@socketio.on('webrtc_connected') +def handle_webrtc_connected(data): + """Client notifies that WebRTC connection is established""" + session_id = request.sid + if session_id not in user_sessions: return - # Combine audio chunks - full_audio = torch.cat(list(context['audio_buffer']), dim=0) - context['audio_buffer'].clear() - context['is_speaking'] = False - context['silence_start'] = None + user_sessions[session_id]['webrtc_connected'] = True + print(f"WebRTC connected for session {session_id}") + emit('ready_for_speech', {'message': 'Ready to start conversation'}) + +@socketio.on('audio_stream') +def handle_audio_stream(data): + """Process incoming audio stream packets from client""" + session_id = request.sid + if session_id not in user_sessions: + return - # Save audio to temporary WAV file for transcription - temp_audio_path = 
f"temp_audio_{session_id}.wav" - torchaudio.save( - temp_audio_path, - full_audio.unsqueeze(0), - 44100 # Assuming 44.1kHz from client - ) + session = user_sessions[session_id] try: - # Try using Whisper first if available - if whisper_model is not None: - user_text = transcribe_with_whisper(temp_audio_path) - else: - # Fallback to Google's speech recognition - user_text = transcribe_with_google(temp_audio_path) - - if not user_text: - print("No speech detected.") - emit('error', {'message': 'No speech detected. Please try again.'}, room=session_id) + # Decode audio data + audio_bytes = base64.b64decode(data.get('audio', '')) + if not audio_bytes or len(audio_bytes) < 2: # Need at least one sample return + + # Add to current audio buffer + session['audio_buffer'] += audio_bytes + + # Check for speech using VAD + has_speech = detect_speech(audio_bytes, session_id) + + # Handle speech state machine + if has_speech: + # Reset silence tracking when speech is detected + session['last_vad_active'] = time.time() + session['silence_duration'] = 0 + session['speech_frames'] += 1 + # If not already marked as speaking and we have enough speech frames + if not session['is_user_speaking'] and session['speech_frames'] >= 5: + on_speech_started(session_id) + else: + # No speech detected in this frame + if session['is_user_speaking']: + # Calculate silence duration + now = time.time() + session['silence_duration'] = now - session['last_vad_active'] + + # If silent for more than 0.5 seconds, end speech segment + if session['silence_duration'] > 0.8 and session['speech_frames'] > 8: + on_speech_ended(session_id) + else: + # Not speaking and no speech, just a silent frame + session['speech_frames'] = max(0, session['speech_frames'] - 1) + + except Exception as e: + print(f"Error processing audio stream: {e}") + +def detect_speech(audio_bytes, session_id): + """Use VAD to check if audio contains speech""" + if session_id not in user_sessions: + return False + + session = user_sessions[session_id] + + # Store in VAD buffer for history + session['vad_buffer'].append(audio_bytes) + + if vad is None: + # Fallback to simple energy detection + audio_data = np.frombuffer(audio_bytes, dtype=np.int16) + energy = np.mean(np.abs(audio_data)) / 32768.0 + return energy > 0.015 # Simple threshold + + try: + # Ensure we have the right amount of data for VAD + audio_data = np.frombuffer(audio_bytes, dtype=np.int16) + + # If we have too much data, use just the right amount + if len(audio_data) >= VAD_FRAME_SIZE: + frame = audio_data[:VAD_FRAME_SIZE].tobytes() + return vad.is_speech(frame, SAMPLE_RATE) + + # If too little data, accumulate in the VAD buffer and check periodically + if len(session['vad_buffer']) >= 3: + # Combine recent chunks to get enough data + combined = bytearray() + for chunk in list(session['vad_buffer'])[-3:]: + combined.extend(chunk) + + # Extract the right amount of data + if len(combined) >= VAD_FRAME_SIZE: + frame = combined[:VAD_FRAME_SIZE] + return vad.is_speech(bytes(frame), SAMPLE_RATE) + + return False + + except Exception as e: + print(f"VAD error: {e}") + return False + +def on_speech_started(session_id): + """Handle start of user speech""" + if session_id not in user_sessions: + return + + session = user_sessions[session_id] + + # Reset audio buffer + session['audio_buffer'] = bytearray() + session['is_user_speaking'] = True + session['is_turn_active'] = True + + # If AI is speaking, we need to interrupt it + if session['is_ai_speaking']: + session['should_interrupt_ai'] = True + 
emit('ai_interrupted_by_user', room=session_id) + + # Notify client that we detected speech + emit('user_speech_start', room=session_id) + +def on_speech_ended(session_id): + """Handle end of user speech segment""" + if session_id not in user_sessions: + return + + session = user_sessions[session_id] + + # Mark as not speaking anymore + session['is_user_speaking'] = False + session['speech_frames'] = 0 + + # If no audio or already processing, skip + if len(session['audio_buffer']) < 4000 or session['is_processing']: # At least 250ms of audio + session['audio_buffer'] = bytearray() + return + + # Mark as processing to prevent multiple processes + session['is_processing'] = True + + # Create a copy of the audio buffer + audio_copy = session['audio_buffer'] + session['audio_buffer'] = bytearray() + + # Convert audio to the format needed for processing + try: + # Convert to float32 between -1 and 1 + audio_np = np.frombuffer(audio_copy, dtype=np.int16).astype(np.float32) / 32768.0 + audio_tensor = torch.from_numpy(audio_np) + + # Resample to Whisper's expected sample rate if necessary + if CLIENT_SAMPLE_RATE != WHISPER_SAMPLE_RATE: + audio_tensor = torchaudio.functional.resample( + audio_tensor, + orig_freq=CLIENT_SAMPLE_RATE, + new_freq=WHISPER_SAMPLE_RATE + ) + + # Save as WAV for transcription + temp_audio_path = f"temp_audio_{session_id}.wav" + torchaudio.save( + temp_audio_path, + audio_tensor.unsqueeze(0), + WHISPER_SAMPLE_RATE + ) + + # Start transcription and response process in a thread + threading.Thread( + target=process_user_utterance, + args=(session_id, temp_audio_path, audio_tensor), + daemon=True + ).start() + + # Notify client that processing has started + emit('processing_speech', room=session_id) + + except Exception as e: + print(f"Error preparing audio: {e}") + session['is_processing'] = False + emit('error', {'message': f'Error processing audio: {str(e)}'}, room=session_id) + +def process_user_utterance(session_id, audio_path, audio_tensor): + """Process user utterance, transcribe and generate response""" + if session_id not in user_sessions: + return + + session = user_sessions[session_id] + + try: + # Transcribe audio + if whisper_model is not None: + user_text = transcribe_with_whisper(audio_path) + else: + # Fallback to another transcription service + user_text = transcribe_fallback(audio_path) + + # Clean up temp file + if os.path.exists(audio_path): + os.remove(audio_path) + + # Check if we got meaningful text + if not user_text or len(user_text.strip()) < 2: + emit('no_speech_detected', room=session_id) + session['is_processing'] = False + return + print(f"Transcribed: {user_text}") - # Add to conversation segments + # Create user segment user_segment = Segment( text=user_text, speaker=0, # User is speaker 0 - audio=full_audio + audio=audio_tensor ) - context['segments'].append(user_segment) + session['segments'].append(user_segment) - # Generate bot response - bot_response = generate_llm_response(user_text, context['segments']) - print(f"Bot response: {bot_response}") + # Update conversation history + session['conversation_history'].append({ + 'role': 'user', + 'text': user_text + }) - # Send transcribed text to client + # Send transcription to client emit('transcription', {'text': user_text}, room=session_id) - # Generate and send audio response if CSM is available - if csm_generator is not None: - # Convert to audio using CSM - bot_audio = generate_audio_response(bot_response, context['segments']) - - # Convert audio to base64 for sending over websocket - 
audio_bytes = io.BytesIO() - torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav") - audio_bytes.seek(0) - audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8') - - # Add bot response to conversation history - bot_segment = Segment( - text=bot_response, - speaker=1, # Bot is speaker 1 - audio=bot_audio - ) - context['segments'].append(bot_segment) - - # Send audio response to client - emit('audio_response', { - 'audio': audio_b64, - 'text': bot_response - }, room=session_id) - else: - # Send text-only response if audio generation isn't available - emit('text_response', {'text': bot_response}, room=session_id) - - # Add text-only bot response to conversation history - bot_segment = Segment( - text=bot_response, - speaker=1, # Bot is speaker 1 - audio=torch.zeros(1) # Placeholder empty audio - ) - context['segments'].append(bot_segment) + # Generate AI response + ai_response = generate_ai_response(user_text, session_id) + # Send text response to client + emit('ai_response_text', {'text': ai_response}, room=session_id) + + # Update conversation history + session['conversation_history'].append({ + 'role': 'assistant', + 'text': ai_response + }) + + # Generate voice response if CSM is available + if csm_generator is not None: + session['is_ai_speaking'] = True + session['should_interrupt_ai'] = False + + # Begin streaming audio response + threading.Thread( + target=stream_ai_response, + args=(ai_response, session_id), + daemon=True + ).start() + except Exception as e: - print(f"Error processing speech: {e}") - emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id) + print(f"Error processing utterance: {e}") + emit('error', {'message': f'Error: {str(e)}'}, room=session_id) + finally: - # Cleanup temp file - if os.path.exists(temp_audio_path): - os.remove(temp_audio_path) + # Clear processing flag + if session_id in user_sessions: + session['is_processing'] = False def transcribe_with_whisper(audio_path): """Transcribe audio using Faster-Whisper""" @@ -250,47 +490,58 @@ def transcribe_with_whisper(audio_path): # Collect all text from segments user_text = "" for segment in segments: - segment_text = segment.text.strip() - print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}") - user_text += segment_text + " " + user_text += segment.text.strip() + " " return user_text.strip() -def transcribe_with_google(audio_path): +def transcribe_fallback(audio_path): """Fallback transcription using Google's speech recognition""" - import speech_recognition as sr - recognizer = sr.Recognizer() - - with sr.AudioFile(audio_path) as source: - audio = recognizer.record(source) - try: - text = recognizer.recognize_google(audio) - return text - except sr.UnknownValueError: - return "" - except sr.RequestError: - # If Google API fails, try a basic energy-based VAD approach - # This is a very basic fallback and won't give good results - return "[Speech detected but transcription failed]" + try: + import speech_recognition as sr + recognizer = sr.Recognizer() + + with sr.AudioFile(audio_path) as source: + audio = recognizer.record(source) + try: + text = recognizer.recognize_google(audio) + return text + except sr.UnknownValueError: + return "" + except sr.RequestError: + return "[Speech recognition service unavailable]" + except ImportError: + return "[Speech recognition not available]" -def generate_llm_response(user_text, conversation_segments): - """Generate text response using available model""" +def 
generate_ai_response(user_text, session_id): + """Generate text response using available LLM""" + if session_id not in user_sessions: + return "I'm sorry, your session has expired." + + session = user_sessions[session_id] + if llm_model is not None and llm_tokenizer is not None: # Format conversation history for the LLM - conversation_history = "" - for segment in conversation_segments[-5:]: # Use last 5 utterances for context - speaker_name = "User" if segment.speaker == 0 else "Assistant" - conversation_history += f"{speaker_name}: {segment.text}\n" + prompt = "You are a helpful, friendly voice assistant. Keep your responses brief and conversational.\n\n" - # Add the current user query - conversation_history += f"User: {user_text}\nAssistant:" + # Add recent conversation history (last 6 turns maximum) + for entry in session['conversation_history'][-6:]: + if entry['role'] == 'user': + prompt += f"User: {entry['text']}\n" + else: + prompt += f"Assistant: {entry['text']}\n" + + # Add current query if not already in history + if not session['conversation_history'] or session['conversation_history'][-1]['role'] != 'user': + prompt += f"User: {user_text}\n" + + prompt += "Assistant: " try: # Generate response - inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device) + inputs = llm_tokenizer(prompt, return_tensors="pt").to(device) output = llm_model.generate( inputs.input_ids, - max_new_tokens=150, + max_new_tokens=100, # Keep responses shorter for voice temperature=0.7, top_p=0.9, do_sample=True @@ -298,40 +549,48 @@ def generate_llm_response(user_text, conversation_segments): response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) return response.strip() + except Exception as e: - print(f"Error generating response with LLM: {e}") + print(f"Error generating LLM response: {e}") return fallback_response(user_text) else: return fallback_response(user_text) def fallback_response(user_text): - """Generate a simple fallback response when LLM is not available""" - # Simple rule-based responses + """Generate simple fallback responses when LLM is unavailable""" user_text_lower = user_text.lower() if "hello" in user_text_lower or "hi" in user_text_lower: - return "Hello! I'm a simple fallback assistant. The main language model couldn't be loaded, so I have limited capabilities." + return "Hello! How can I help you today?" elif "how are you" in user_text_lower: - return "I'm functioning within my limited capabilities. How can I assist you today?" + return "I'm doing well, thanks for asking! How about you?" elif "thank" in user_text_lower: - return "You're welcome! Let me know if there's anything else I can help with." + return "You're welcome! Happy to help." elif "bye" in user_text_lower or "goodbye" in user_text_lower: return "Goodbye! Have a great day!" elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]): - return "I'm running in fallback mode and can't answer complex questions. Please try again when the main language model is available." + return "That's an interesting question. I wish I could provide a better answer in my current fallback mode." else: - return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available." + return "I see. Tell me more about that." 
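# A minimal sketch of an alternative prompt path for generate_ai_response(),
# assuming an instruct-tuned checkpoint whose tokenizer ships a chat template
# (e.g. "meta-llama/Llama-3.2-1B-Instruct" -- an assumption; the patch itself
# loads the base Llama-3.2-1B, which has no chat template, hence the hand-built
# "User:/Assistant:" prompt above). It reuses the module-level llm_tokenizer,
# llm_model and device, and the same session['conversation_history'] structure.
def build_and_generate_with_chat_template(session, user_text):
    messages = [{"role": "system",
                 "content": "You are a helpful, friendly voice assistant. "
                            "Keep your responses brief and conversational."}]
    # History entries already use 'user' / 'assistant' roles.
    for entry in session['conversation_history'][-6:]:
        messages.append({"role": entry['role'], "content": entry['text']})
    if messages[-1]["role"] != "user":
        messages.append({"role": "user", "content": user_text})
    input_ids = llm_tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    output = llm_model.generate(input_ids, max_new_tokens=100,
                                temperature=0.7, top_p=0.9, do_sample=True)
    # Decode only the newly generated tokens.
    return llm_tokenizer.decode(output[0][input_ids.shape[1]:],
                                skip_special_tokens=True).strip()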
-def generate_audio_response(text, conversation_segments): - """Generate audio response using CSM""" +def stream_ai_response(text, session_id): + """Generate and stream audio response in real-time chunks""" + if session_id not in user_sessions: + return + + session = user_sessions[session_id] + try: - # Use the last few conversation segments as context - context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments + # Signal start of AI speech + emit('ai_speech_start', room=session_id) + + # Use the last few conversation segments as context (up to 4) + context_segments = session['segments'][-4:] if len(session['segments']) > 4 else session['segments'] # Generate audio for bot response audio = csm_generator.generate( @@ -343,11 +602,77 @@ def generate_audio_response(text, conversation_segments): topk=50 ) - return audio + # Create and store bot segment + bot_segment = Segment( + text=text, + speaker=1, + audio=audio + ) + + if session_id in user_sessions: + session['segments'].append(bot_segment) + + # Stream audio in small chunks for more responsive playback + chunk_size = AUDIO_CHUNK_SIZE # Size defined in constants + + for i in range(0, len(audio), chunk_size): + # Check if we should stop (user interrupted) + if session_id not in user_sessions or session['should_interrupt_ai']: + print("AI speech interrupted") + break + + # Get next chunk + chunk = audio[i:i+chunk_size] + + # Convert audio chunk to base64 for streaming + audio_bytes = io.BytesIO() + torchaudio.save(audio_bytes, chunk.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav") + audio_bytes.seek(0) + audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8') + + # Send chunk to client + socketio.emit('ai_speech_chunk', { + 'audio': audio_b64, + 'is_last': i + chunk_size >= len(audio) + }, room=session_id) + + # Small sleep for more natural pacing + time.sleep(0.06) # Slight delay for smoother playback + + # Signal end of AI speech + if session_id in user_sessions: + session['is_ai_speaking'] = False + session['is_turn_active'] = False # End conversation turn + socketio.emit('ai_speech_end', room=session_id) + except Exception as e: - print(f"Error generating audio: {e}") - # Return silence as fallback - return torch.zeros(csm_generator.sample_rate * 3) # 3 seconds of silence + print(f"Error streaming AI response: {e}") + if session_id in user_sessions: + session['is_ai_speaking'] = False + session['is_turn_active'] = False + socketio.emit('error', {'message': f'Error generating audio: {str(e)}'}, room=session_id) + socketio.emit('ai_speech_end', room=session_id) + +@socketio.on('interrupt_ai') +def handle_interrupt(): + """Handle explicit AI interruption request from client""" + session_id = request.sid + if session_id in user_sessions: + user_sessions[session_id]['should_interrupt_ai'] = True + emit('ai_interrupted', room=session_id) + +@socketio.on('get_config') +def handle_get_config(): + """Send configuration to client""" + session_id = request.sid + if session_id in user_sessions: + emit('config', { + 'client_sample_rate': CLIENT_SAMPLE_RATE, + 'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000, + 'whisper_available': whisper_model is not None, + 'csm_available': csm_generator is not None, + 'ice_servers': ICE_SERVERS + }) if __name__ == '__main__': # Ensure the existing index.html file is in the correct location @@ -357,9 +682,8 @@ if __name__ == '__main__': if os.path.exists('index.html') and not 
os.path.exists('templates/index.html'): os.rename('index.html', 'templates/index.html') - # Load models asynchronously before starting the server - print("Starting CPU-only model loading...") - # In a production environment, you could load models in a separate thread + # Load models before starting the server + print("Starting model loading...") load_models() # Start the server diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js new file mode 100644 index 0000000..12bac9a --- /dev/null +++ b/Backend/voice-chat.js @@ -0,0 +1,560 @@ +document.addEventListener('DOMContentLoaded', () => { + // DOM Elements + const startButton = document.getElementById('start-button'); + const interruptButton = document.getElementById('interrupt-button'); + const conversationDiv = document.getElementById('conversation'); + const connectionDot = document.getElementById('connection-dot'); + const connectionStatus = document.getElementById('connection-status'); + const whisperStatus = document.getElementById('whisper-status'); + const csmStatus = document.getElementById('csm-status'); + const llmStatus = document.getElementById('llm-status'); + const webrtcStatus = document.getElementById('webrtc-status'); + const micAnimation = document.getElementById('mic-animation'); + const loadingDiv = document.getElementById('loading'); + const loadingText = document.getElementById('loading-text'); + + // State variables + let socket; + let isConnected = false; + let isListening = false; + let isAiSpeaking = false; + let audioContext; + let mediaStream; + let audioRecorder; + let audioProcessor; + const audioChunks = []; + + // WebRTC variables + let peerConnection; + let dataChannel; + let hasActiveConnection = false; + + // Audio playback + let audioQueue = []; + let isPlaying = false; + + // Configuration variables + let serverSampleRate = 24000; + let clientSampleRate = 44100; + let iceServers = []; + + // Initialize the application + initApp(); + + // Main initialization function + function initApp() { + updateConnectionStatus('connecting'); + setupSocketConnection(); + setupEventListeners(); + } + + // Set up Socket.IO connection with server + function setupSocketConnection() { + socket = io(); + + socket.on('connect', () => { + console.log('Connected to server'); + updateConnectionStatus('connected'); + isConnected = true; + }); + + socket.on('disconnect', () => { + console.log('Disconnected from server'); + updateConnectionStatus('disconnected'); + isConnected = false; + cleanupAudio(); + cleanupWebRTC(); + }); + + socket.on('session_ready', (data) => { + console.log('Session ready:', data); + updateModelStatus(data); + clientSampleRate = data.client_sample_rate; + serverSampleRate = data.server_sample_rate; + iceServers = data.ice_servers; + + // Initialize WebRTC if models are available + if (data.whisper_available && data.llm_available) { + initializeWebRTC(); + } + }); + + socket.on('ready_for_speech', (data) => { + console.log('Ready for speech:', data); + startButton.disabled = false; + addInfoMessage('Ready for conversation. 
Click "Start Listening" to begin.'); + }); + + socket.on('webrtc_signal', (data) => { + handleWebRTCSignal(data); + }); + + socket.on('transcription', (data) => { + console.log('Transcription:', data); + addUserMessage(data.text); + loadingDiv.style.display = 'none'; + }); + + socket.on('ai_response_text', (data) => { + console.log('AI response text:', data); + addAIMessage(data.text); + loadingDiv.style.display = 'none'; + }); + + socket.on('ai_speech_start', () => { + console.log('AI started speaking'); + isAiSpeaking = true; + interruptButton.disabled = false; + }); + + socket.on('ai_speech_chunk', (data) => { + console.log('Received AI speech chunk'); + playAudioChunk(data.audio, data.is_last); + }); + + socket.on('ai_speech_end', () => { + console.log('AI stopped speaking'); + isAiSpeaking = false; + interruptButton.disabled = true; + }); + + socket.on('user_speech_start', () => { + console.log('User speech detected'); + showSpeakingIndicator(true); + }); + + socket.on('processing_speech', () => { + console.log('Processing speech'); + showSpeakingIndicator(false); + showLoadingIndicator('Processing your speech...'); + }); + + socket.on('no_speech_detected', () => { + console.log('No speech detected'); + hideLoadingIndicator(); + addInfoMessage('No speech detected. Please try again.'); + }); + + socket.on('ai_interrupted', () => { + console.log('AI interrupted'); + clearAudioQueue(); + isAiSpeaking = false; + interruptButton.disabled = true; + }); + + socket.on('ai_interrupted_by_user', () => { + console.log('AI interrupted by user'); + clearAudioQueue(); + isAiSpeaking = false; + interruptButton.disabled = true; + addInfoMessage('AI interrupted by your speech'); + }); + + socket.on('error', (data) => { + console.error('Server error:', data); + hideLoadingIndicator(); + addInfoMessage(`Error: ${data.message}`); + }); + } + + // Set up UI event listeners + function setupEventListeners() { + startButton.addEventListener('click', toggleListening); + interruptButton.addEventListener('click', interruptAI); + } + + // Update UI connection status + function updateConnectionStatus(status) { + connectionDot.className = 'status-dot ' + status; + + switch (status) { + case 'connected': + connectionStatus.textContent = 'Connected'; + break; + case 'connecting': + connectionStatus.textContent = 'Connecting...'; + break; + case 'disconnected': + connectionStatus.textContent = 'Disconnected'; + startButton.disabled = true; + interruptButton.disabled = true; + break; + } + } + + // Update model status indicators + function updateModelStatus(data) { + whisperStatus.textContent = data.whisper_available ? 'Available' : 'Not Available'; + whisperStatus.style.color = data.whisper_available ? 'green' : 'red'; + + csmStatus.textContent = data.csm_available ? 'Available' : 'Not Available'; + csmStatus.style.color = data.csm_available ? 'green' : 'red'; + + llmStatus.textContent = data.llm_available ? 'Available' : 'Not Available'; + llmStatus.style.color = data.llm_available ? 
'green' : 'red'; + } + + // Initialize WebRTC connection + function initializeWebRTC() { + if (!isConnected) return; + + const configuration = { + iceServers: iceServers + }; + + peerConnection = new RTCPeerConnection(configuration); + + // Create data channel for WebRTC communication + dataChannel = peerConnection.createDataChannel('audioData', { + ordered: true + }); + + dataChannel.onopen = () => { + console.log('WebRTC data channel open'); + hasActiveConnection = true; + webrtcStatus.textContent = 'Connected'; + webrtcStatus.style.color = 'green'; + socket.emit('webrtc_connected', { status: 'connected' }); + }; + + dataChannel.onclose = () => { + console.log('WebRTC data channel closed'); + hasActiveConnection = false; + webrtcStatus.textContent = 'Disconnected'; + webrtcStatus.style.color = 'red'; + }; + + // Handle ICE candidates + peerConnection.onicecandidate = (event) => { + if (event.candidate) { + socket.emit('webrtc_signal', { + type: 'ice_candidate', + candidate: event.candidate + }); + } + }; + + // Log ICE connection state changes + peerConnection.oniceconnectionstatechange = () => { + console.log('ICE connection state:', peerConnection.iceConnectionState); + }; + + // Create offer + peerConnection.createOffer() + .then(offer => peerConnection.setLocalDescription(offer)) + .then(() => { + socket.emit('webrtc_signal', { + type: 'offer', + sdp: peerConnection.localDescription + }); + }) + .catch(error => { + console.error('Error creating WebRTC offer:', error); + webrtcStatus.textContent = 'Failed to Connect'; + webrtcStatus.style.color = 'red'; + }); + } + + // Handle WebRTC signals from the server + function handleWebRTCSignal(data) { + if (!peerConnection) return; + + if (data.type === 'answer') { + peerConnection.setRemoteDescription(new RTCSessionDescription(data.sdp)) + .catch(error => console.error('Error setting remote description:', error)); + } + else if (data.type === 'ice_candidate') { + peerConnection.addIceCandidate(new RTCIceCandidate(data.candidate)) + .catch(error => console.error('Error adding ICE candidate:', error)); + } + } + + // Clean up WebRTC connection + function cleanupWebRTC() { + if (dataChannel) { + dataChannel.close(); + } + + if (peerConnection) { + peerConnection.close(); + } + + dataChannel = null; + peerConnection = null; + hasActiveConnection = false; + webrtcStatus.textContent = 'Not Connected'; + webrtcStatus.style.color = 'red'; + } + + // Toggle audio listening + function toggleListening() { + if (isListening) { + stopListening(); + } else { + startListening(); + } + } + + // Start listening for audio + async function startListening() { + if (!isConnected) return; + + try { + await initAudio(); + isListening = true; + startButton.textContent = 'Stop Listening'; + startButton.innerHTML = ` + + + + Stop Listening + `; + } catch (error) { + console.error('Error starting audio:', error); + addInfoMessage('Error accessing microphone. 
Please check permissions.'); + } + } + + // Stop listening for audio + function stopListening() { + cleanupAudio(); + isListening = false; + startButton.innerHTML = ` + + + + Start Listening + `; + showSpeakingIndicator(false); + } + + // Initialize audio capture + async function initAudio() { + // Request microphone access + mediaStream = await navigator.mediaDevices.getUserMedia({ + audio: { + sampleRate: clientSampleRate, + channelCount: 1, + echoCancellation: true, + noiseSuppression: true, + autoGainControl: true + } + }); + + // Initialize AudioContext + audioContext = new (window.AudioContext || window.webkitAudioContext)({ + sampleRate: clientSampleRate + }); + + // Create audio source from stream + const source = audioContext.createMediaStreamSource(mediaStream); + + // Create ScriptProcessor for audio processing + const bufferSize = 4096; + audioProcessor = audioContext.createScriptProcessor(bufferSize, 1, 1); + + // Process audio data + audioProcessor.onaudioprocess = (event) => { + if (!isListening || isAiSpeaking) return; + + const input = event.inputBuffer.getChannelData(0); + const audioData = convertFloat32ToInt16(input); + sendAudioChunk(audioData); + }; + + // Connect the nodes + source.connect(audioProcessor); + audioProcessor.connect(audioContext.destination); + } + + // Clean up audio resources + function cleanupAudio() { + if (audioProcessor) { + audioProcessor.disconnect(); + audioProcessor = null; + } + + if (mediaStream) { + mediaStream.getTracks().forEach(track => track.stop()); + mediaStream = null; + } + + if (audioContext && audioContext.state !== 'closed') { + audioContext.close().catch(error => console.error('Error closing AudioContext:', error)); + } + + audioChunks.length = 0; + } + + // Convert Float32Array to Int16Array for sending to server + function convertFloat32ToInt16(float32Array) { + const int16Array = new Int16Array(float32Array.length); + for (let i = 0; i < float32Array.length; i++) { + // Convert float [-1.0, 1.0] to int16 [-32768, 32767] + int16Array[i] = Math.max(-32768, Math.min(32767, Math.floor(float32Array[i] * 32768))); + } + return int16Array; + } + + // Send audio chunk to server + function sendAudioChunk(audioData) { + if (!isConnected || !isListening) return; + + // Convert to base64 for transmission + const base64Audio = arrayBufferToBase64(audioData.buffer); + + // Send via Socket.IO (could use WebRTC's DataChannel for lower latency in production) + socket.emit('audio_stream', { audio: base64Audio }); + } + + // Play audio chunk received from server + function playAudioChunk(base64Audio, isLast) { + const audioData = base64ToArrayBuffer(base64Audio); + + // Add to queue + audioQueue.push({ + data: audioData, + isLast: isLast + }); + + // Start playing if not already playing + if (!isPlaying) { + playNextAudioChunk(); + } + } + + // Play the next audio chunk in the queue + function playNextAudioChunk() { + if (audioQueue.length === 0) { + isPlaying = false; + return; + } + + isPlaying = true; + const chunk = audioQueue.shift(); + + try { + // Create audio context if needed + if (!audioContext || audioContext.state === 'closed') { + audioContext = new (window.AudioContext || window.webkitAudioContext)(); + } + + // Resume audio context if suspended + if (audioContext.state === 'suspended') { + audioContext.resume(); + } + + // Decode the WAV data + audioContext.decodeAudioData(chunk.data, (buffer) => { + const source = audioContext.createBufferSource(); + source.buffer = buffer; + source.connect(audioContext.destination); + + // 
When playback ends, play the next chunk + source.onended = () => { + playNextAudioChunk(); + }; + + source.start(0); + + // If it's the last chunk, update UI + if (chunk.isLast) { + setTimeout(() => { + isAiSpeaking = false; + interruptButton.disabled = true; + }, buffer.duration * 1000); + } + }, (error) => { + console.error('Error decoding audio data:', error); + playNextAudioChunk(); // Skip this chunk and try the next + }); + } catch (error) { + console.error('Error playing audio chunk:', error); + playNextAudioChunk(); // Try the next chunk + } + } + + // Clear the audio queue (used when interrupting) + function clearAudioQueue() { + audioQueue.length = 0; + isPlaying = false; + + // Stop any currently playing audio + if (audioContext) { + audioContext.suspend(); + } + } + + // Send interrupt signal to server + function interruptAI() { + if (!isConnected || !isAiSpeaking) return; + + socket.emit('interrupt_ai'); + clearAudioQueue(); + } + + // Convert ArrayBuffer to Base64 string + function arrayBufferToBase64(buffer) { + const binary = new Uint8Array(buffer); + let base64 = ''; + const len = binary.byteLength; + for (let i = 0; i < len; i++) { + base64 += String.fromCharCode(binary[i]); + } + return window.btoa(base64); + } + + // Convert Base64 string to ArrayBuffer + function base64ToArrayBuffer(base64) { + const binaryString = window.atob(base64); + const len = binaryString.length; + const bytes = new Uint8Array(len); + for (let i = 0; i < len; i++) { + bytes[i] = binaryString.charCodeAt(i); + } + return bytes.buffer; + } + + // Add user message to conversation + function addUserMessage(text) { + const messageDiv = document.createElement('div'); + messageDiv.className = 'message user-message'; + messageDiv.textContent = text; + conversationDiv.appendChild(messageDiv); + conversationDiv.scrollTop = conversationDiv.scrollHeight; + } + + // Add AI message to conversation + function addAIMessage(text) { + const messageDiv = document.createElement('div'); + messageDiv.className = 'message ai-message'; + messageDiv.textContent = text; + conversationDiv.appendChild(messageDiv); + conversationDiv.scrollTop = conversationDiv.scrollHeight; + } + + // Add info message to conversation + function addInfoMessage(text) { + const messageDiv = document.createElement('div'); + messageDiv.className = 'info-message'; + messageDiv.textContent = text; + conversationDiv.appendChild(messageDiv); + conversationDiv.scrollTop = conversationDiv.scrollHeight; + } + + // Show/hide speaking indicator + function showSpeakingIndicator(show) { + micAnimation.style.display = show ? 
'flex' : 'none'; + } + + // Show loading indicator + function showLoadingIndicator(text) { + loadingText.textContent = text || 'Processing...'; + loadingDiv.style.display = 'block'; + } + + // Hide loading indicator + function hideLoadingIndicator() { + loadingDiv.style.display = 'none'; + } +}); \ No newline at end of file diff --git a/React/public/icon-128x128.png b/React/public/icon-128x128.png new file mode 100644 index 0000000..a0ffe32 Binary files /dev/null and b/React/public/icon-128x128.png differ diff --git a/React/public/icon-512x512.png b/React/public/icon-512x512.png new file mode 100644 index 0000000..cb560a1 Binary files /dev/null and b/React/public/icon-512x512.png differ diff --git a/React/src/app/layout.tsx b/React/src/app/layout.tsx index 43900a8..8dccd6b 100644 --- a/React/src/app/layout.tsx +++ b/React/src/app/layout.tsx @@ -13,8 +13,8 @@ const geistMono = Geist_Mono({ }); export const metadata: Metadata = { - title: "Create Next App", - description: "Generated by create next app", + title: "Fauxcall", + description: "Fauxcall is a fake call app that helps you get out of awkward situations.", }; export default function RootLayout({ diff --git a/React/src/app/manifest.ts b/React/src/app/manifest.ts new file mode 100644 index 0000000..1727948 --- /dev/null +++ b/React/src/app/manifest.ts @@ -0,0 +1,25 @@ +import type { MetadataRoute } from 'next' + +export default function manifest(): MetadataRoute.Manifest { + return { + name: 'Fauxcall', + short_name: 'Fauxcall', + description: 'A fake call app that helps you get out of awkward and dangerous situations.', + start_url: '/', + display: 'standalone', + background_color: '#ffffff', + theme_color: '#000000', + icons: [ + { + src: '/icon-192x192.png', + sizes: '192x192', + type: 'image/png', + }, + { + src: '/icon-512x512.png', + sizes: '512x512', + type: 'image/png', + }, + ], + } +} \ No newline at end of file diff --git a/React/src/app/page.tsx b/React/src/app/page.tsx index f5d7ff9..a37023d 100644 --- a/React/src/app/page.tsx +++ b/React/src/app/page.tsx @@ -4,7 +4,7 @@ import { useRouter } from "next/navigation"; import './styles.css'; export default function Home() { - const [contacts, setContacts] = useState([]); + const [contacts, setContacts] = useState([""]); const [codeword, setCodeword] = useState(""); const [session, setSession] = useState(null); const [loading, setLoading] = useState(true); @@ -26,6 +26,16 @@ export default function Home() { }); }, []); + const handleInputChange = (index: number, value: string) => { + const updatedContacts = [...contacts]; + updatedContacts[index] = value; // Update the specific input value + setContacts(updatedContacts); + }; + + const addContactInput = () => { + setContacts([...contacts, ""]); // Add a new empty input + }; + function saveToDB() { alert("Saving contacts..."); const contactInputs = document.querySelectorAll( @@ -144,27 +154,20 @@ export default function Home() { className="space-y-5 flex flex-col gap-[32px] row-start-2 items-center sm:items-start" onSubmit={(e) => e.preventDefault()} > - setContacts(e.target.value.split(","))} - placeholder="Write down an emergency contact" - className="border border-gray-300 rounded-md p-2" - /> - - - + {contacts.map((contact, index) => ( + handleInputChange(index, e.target.value)} + placeholder={`Contact ${index + 1}`} + className="border border-gray-300 rounded-md p-2" + /> + ))}