From d4a7cf0e2fcdd02ae37fd0df4e598d670eb9294e Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 03:43:08 -0400 Subject: [PATCH] Frontend Fixed --- Backend/api/app.py | 136 ++++ Backend/{ => api}/generator.py | 0 Backend/{ => api}/models.py | 0 Backend/api/routes.py | 74 ++ Backend/api/socket_handlers.py | 392 ++++++++++ Backend/{ => api}/watermarking.py | 0 Backend/index.html | 419 ----------- Backend/requirements.txt | 13 - Backend/run_csm.py | 117 --- Backend/server.py | 646 +--------------- Backend/setup.py | 13 - Backend/voice-chat.js | 1054 --------------------------- React/src/app/auth/session/route.ts | 12 + React/src/app/page.tsx | 236 +++--- 14 files changed, 777 insertions(+), 2335 deletions(-) create mode 100644 Backend/api/app.py rename Backend/{ => api}/generator.py (100%) rename Backend/{ => api}/models.py (100%) create mode 100644 Backend/api/routes.py create mode 100644 Backend/api/socket_handlers.py rename Backend/{ => api}/watermarking.py (100%) delete mode 100644 Backend/index.html delete mode 100644 Backend/requirements.txt delete mode 100644 Backend/run_csm.py delete mode 100644 Backend/setup.py delete mode 100644 Backend/voice-chat.js create mode 100644 React/src/app/auth/session/route.ts diff --git a/Backend/api/app.py b/Backend/api/app.py new file mode 100644 index 0000000..018061f --- /dev/null +++ b/Backend/api/app.py @@ -0,0 +1,136 @@ +import os +import logging +import threading +from dataclasses import dataclass +from flask import Flask +from flask_socketio import SocketIO +from flask_cors import CORS + +# Configure logging +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Configure device +import torch +if torch.cuda.is_available(): + DEVICE = "cuda" +elif torch.backends.mps.is_available(): + DEVICE = "mps" +else: + DEVICE = "cpu" + +logger.info(f"Using device: {DEVICE}") + +# Initialize Flask app +app = Flask(__name__, static_folder='../', static_url_path='') +CORS(app) +socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120) + +# Global variables for conversation state +active_conversations = {} +user_queues = {} +processing_threads = {} + +# Model storage +@dataclass +class AppModels: + generator = None + tokenizer = None + llm = None + whisperx_model = None + whisperx_align_model = None + whisperx_align_metadata = None + last_language = None + +models = AppModels() + +def load_models(): + """Load all required models""" + from generator import load_csm_1b + import whisperx + import gc + from transformers import AutoModelForCausalLM, AutoTokenizer + global models + + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0}) + + # CSM 1B loading + try: + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'}) + models.generator = load_csm_1b(device=DEVICE) + logger.info("CSM 1B model loaded successfully") + socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33}) + if DEVICE == "cuda": + torch.cuda.empty_cache() + except Exception as e: + import traceback + error_details = traceback.format_exc() + logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") + socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) + + # WhisperX loading + try: + socketio.emit('model_status', {'model': 'overall', 
'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'}) + # Use WhisperX for better transcription with timestamps + # Use compute_type based on device + compute_type = "float16" if DEVICE == "cuda" else "float32" + + # Load the WhisperX model (smaller model for faster processing) + models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type) + + logger.info("WhisperX model loaded successfully") + socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66}) + if DEVICE == "cuda": + torch.cuda.empty_cache() + except Exception as e: + import traceback + error_details = traceback.format_exc() + logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}") + socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) + + # Llama loading + try: + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'}) + models.llm = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B", + device_map=DEVICE, + torch_dtype=torch.bfloat16 + ) + models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + + # Configure all special tokens + models.tokenizer.pad_token = models.tokenizer.eos_token + models.tokenizer.padding_side = "left" # For causal language modeling + + # Inform the model about the pad token + if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None: + models.llm.config.pad_token_id = models.tokenizer.pad_token_id + + logger.info("Llama 3.2 model loaded successfully") + socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'}) + except Exception as e: + logger.error(f"Error loading Llama 3.2 model: {str(e)}") + socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)}) + +# Load models in a background thread +threading.Thread(target=load_models, daemon=True).start() + +# Import routes and socket handlers +from api.routes import register_routes +from api.socket_handlers import register_handlers + +# Register routes and socket handlers +register_routes(app) +register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE) + +# Run server if executed directly +if __name__ == '__main__': + port = int(os.environ.get('PORT', 5000)) + debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true' + logger.info(f"Starting server on port {port} (debug={debug_mode})") + socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/Backend/generator.py b/Backend/api/generator.py similarity index 100% rename from Backend/generator.py rename to Backend/api/generator.py diff --git a/Backend/models.py b/Backend/api/models.py similarity index 100% rename from Backend/models.py rename to Backend/api/models.py diff --git a/Backend/api/routes.py b/Backend/api/routes.py new file mode 100644 index 0000000..af1bfce --- /dev/null +++ b/Backend/api/routes.py @@ -0,0 +1,74 @@ +import os +import torch +import psutil +from flask import send_from_directory, jsonify, request + +def register_routes(app): + """Register HTTP routes for the application""" + + 
@app.route('/') + def index(): + """Serve the main application page""" + return send_from_directory(app.static_folder, 'index.html') + + @app.route('/voice-chat.js') + def serve_js(): + """Serve the JavaScript file""" + return send_from_directory(app.static_folder, 'voice-chat.js') + + @app.route('/api/status') + def system_status(): + """Return the system status""" + # Import here to avoid circular imports + from api.app import models, DEVICE + + return jsonify({ + "status": "ok", + "cuda_available": torch.cuda.is_available(), + "device": DEVICE, + "models": { + "generator": models.generator is not None, + "asr": models.whisperx_model is not None, + "llm": models.llm is not None + }, + "versions": { + "transformers": "4.49.0", # Replace with actual version + "torch": torch.__version__ + } + }) + + @app.route('/api/system_resources') + def system_resources(): + """Return system resource usage""" + # Import here to avoid circular imports + from api.app import active_conversations, DEVICE + + # Get CPU usage + cpu_percent = psutil.cpu_percent(interval=0.1) + + # Get memory usage + memory = psutil.virtual_memory() + memory_used_gb = memory.used / (1024 ** 3) + memory_total_gb = memory.total / (1024 ** 3) + memory_percent = memory.percent + + # Get GPU memory if available + gpu_memory = {} + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + gpu_memory[f"gpu_{i}"] = { + "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3), + "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3), + "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3) + } + + return jsonify({ + "cpu_percent": cpu_percent, + "memory": { + "used_gb": memory_used_gb, + "total_gb": memory_total_gb, + "percent": memory_percent + }, + "gpu_memory": gpu_memory, + "active_sessions": len(active_conversations) + }) \ No newline at end of file diff --git a/Backend/api/socket_handlers.py b/Backend/api/socket_handlers.py new file mode 100644 index 0000000..20513e9 --- /dev/null +++ b/Backend/api/socket_handlers.py @@ -0,0 +1,392 @@ +import os +import io +import base64 +import time +import threading +import queue +import tempfile +import gc +import logging +import traceback +from typing import Dict, List, Optional + +import torch +import torchaudio +import numpy as np +from flask import request +from flask_socketio import emit + +# Import conversation model +from generator import Segment + +logger = logging.getLogger(__name__) + +# Conversation data structure +class Conversation: + def __init__(self, session_id): + self.session_id = session_id + self.segments: List[Segment] = [] + self.current_speaker = 0 + self.ai_speaker_id = 1 # Default AI speaker ID + self.last_activity = time.time() + self.is_processing = False + + def add_segment(self, text, speaker, audio): + segment = Segment(text=text, speaker=speaker, audio=audio) + self.segments.append(segment) + self.last_activity = time.time() + return segment + + def get_context(self, max_segments=10): + """Return the most recent segments for context""" + return self.segments[-max_segments:] if self.segments else [] + +def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE): + """Register Socket.IO event handlers""" + + @socketio.on('connect') + def handle_connect(auth=None): + """Handle client connection""" + session_id = request.sid + logger.info(f"Client connected: {session_id}") + + # Initialize conversation data + if session_id not in active_conversations: + active_conversations[session_id] = 
Conversation(session_id) + user_queues[session_id] = queue.Queue() + processing_threads[session_id] = threading.Thread( + target=process_audio_queue, + args=(session_id, user_queues[session_id], app, socketio, models, active_conversations, DEVICE), + daemon=True + ) + processing_threads[session_id].start() + + emit('connection_status', {'status': 'connected'}) + + @socketio.on('disconnect') + def handle_disconnect(reason=None): + """Handle client disconnection""" + session_id = request.sid + logger.info(f"Client disconnected: {session_id}. Reason: {reason}") + + # Cleanup + if session_id in active_conversations: + # Mark for deletion rather than immediately removing + # as the processing thread might still be accessing it + active_conversations[session_id].is_processing = False + user_queues[session_id].put(None) # Signal thread to terminate + + @socketio.on('audio_data') + def handle_audio_data(data): + """Handle incoming audio data""" + session_id = request.sid + logger.info(f"Received audio data from {session_id}") + + # Check if the models are loaded + if models.generator is None or models.whisperx_model is None or models.llm is None: + emit('error', {'message': 'Models still loading, please wait'}) + return + + # Check if we're already processing for this session + if session_id in active_conversations and active_conversations[session_id].is_processing: + emit('error', {'message': 'Still processing previous audio, please wait'}) + return + + # Add to processing queue + if session_id in user_queues: + user_queues[session_id].put(data) + else: + emit('error', {'message': 'Session not initialized, please refresh the page'}) + +def process_audio_queue(session_id, q, app, socketio, models, active_conversations, DEVICE): + """Background thread to process audio chunks for a session""" + logger.info(f"Started processing thread for session: {session_id}") + + try: + while session_id in active_conversations: + try: + # Get the next audio chunk with a timeout + data = q.get(timeout=120) + if data is None: # Termination signal + break + + # Process the audio and generate a response + process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE) + + except queue.Empty: + # Timeout, check if session is still valid + continue + except Exception as e: + logger.error(f"Error processing audio for {session_id}: {str(e)}") + # Create an app context for the socket emit + with app.app_context(): + socketio.emit('error', {'message': str(e)}, room=session_id) + finally: + logger.info(f"Ending processing thread for session: {session_id}") + # Clean up when thread is done + with app.app_context(): + if session_id in active_conversations: + del active_conversations[session_id] + if session_id in user_queues: + del user_queues[session_id] + +def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE): + """Process audio data and generate a response using WhisperX""" + if models.generator is None or models.whisperx_model is None or models.llm is None: + logger.warning("Models not yet loaded!") + with app.app_context(): + socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) + return + + logger.info(f"Processing audio for session {session_id}") + conversation = active_conversations[session_id] + + try: + # Set processing flag + conversation.is_processing = True + + # Process base64 audio data + audio_data = data['audio'] + speaker_id = data['speaker'] + logger.info(f"Received audio from speaker 
{speaker_id}") + + # Convert from base64 to WAV + try: + audio_bytes = base64.b64decode(audio_data.split(',')[1]) + logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes") + except Exception as e: + logger.error(f"Error decoding base64 audio: {str(e)}") + raise + + # Save to temporary file for processing + with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: + temp_file.write(audio_bytes) + temp_path = temp_file.name + + try: + # Notify client that transcription is starting + with app.app_context(): + socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) + + # Load audio using WhisperX + import whisperx + audio = whisperx.load_audio(temp_path) + + # Check audio length and add a warning for short clips + audio_length = len(audio) / 16000 # assuming 16kHz sample rate + if audio_length < 1.0: + logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality") + + # Transcribe using WhisperX + batch_size = 16 # adjust based on your GPU memory + logger.info("Running WhisperX transcription...") + + # Handle the warning about audio being shorter than 30s by suppressing it + import warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="audio is shorter than 30s") + result = models.whisperx_model.transcribe(audio, batch_size=batch_size) + + # Get the detected language + language_code = result["language"] + logger.info(f"Detected language: {language_code}") + + # Check if alignment model needs to be loaded or updated + if models.whisperx_align_model is None or language_code != models.last_language: + # Clean up old models if they exist + if models.whisperx_align_model is not None: + del models.whisperx_align_model + del models.whisperx_align_metadata + if DEVICE == "cuda": + gc.collect() + torch.cuda.empty_cache() + + # Load new alignment model for the detected language + logger.info(f"Loading alignment model for language: {language_code}") + models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model( + language_code=language_code, device=DEVICE + ) + models.last_language = language_code + + # Align the transcript to get word-level timestamps + if result["segments"] and len(result["segments"]) > 0: + logger.info("Aligning transcript...") + result = whisperx.align( + result["segments"], + models.whisperx_align_model, + models.whisperx_align_metadata, + audio, + DEVICE, + return_char_alignments=False + ) + + # Process the segments for better output + for segment in result["segments"]: + # Round timestamps for better display + segment["start"] = round(segment["start"], 2) + segment["end"] = round(segment["end"], 2) + # Add a confidence score if not present + if "confidence" not in segment: + segment["confidence"] = 1.0 # Default confidence + + # Extract the full text from all segments + user_text = ' '.join([segment['text'] for segment in result['segments']]) + + # If no text was recognized, don't process further + if not user_text or len(user_text.strip()) == 0: + with app.app_context(): + socketio.emit('error', {'message': 'No speech detected'}, room=session_id) + return + + logger.info(f"Transcription: {user_text}") + + # Load audio for CSM input + waveform, sample_rate = torchaudio.load(temp_path) + + # Normalize to mono if needed + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + # Resample to the CSM sample rate if needed + if sample_rate != models.generator.sample_rate: + waveform = torchaudio.functional.resample( + 
waveform, + orig_freq=sample_rate, + new_freq=models.generator.sample_rate + ) + + # Add the user's message to conversation history + user_segment = conversation.add_segment( + text=user_text, + speaker=speaker_id, + audio=waveform.squeeze() + ) + + # Send transcription to client with detailed segments + with app.app_context(): + socketio.emit('transcription', { + 'text': user_text, + 'speaker': speaker_id, + 'segments': result['segments'] # Include the detailed segments with timestamps + }, room=session_id) + + # Generate AI response using Llama + with app.app_context(): + socketio.emit('processing_status', {'status': 'generating'}, room=session_id) + + # Create prompt from conversation history + conversation_history = "" + for segment in conversation.segments[-5:]: # Last 5 segments for context + role = "User" if segment.speaker == 0 else "Assistant" + conversation_history += f"{role}: {segment.text}\n" + + # Add final prompt + prompt = f"{conversation_history}Assistant: " + + # Generate response with Llama + try: + # Ensure pad token is set + if models.tokenizer.pad_token is None: + models.tokenizer.pad_token = models.tokenizer.eos_token + + input_tokens = models.tokenizer( + prompt, + return_tensors="pt", + padding=True, + return_attention_mask=True + ) + input_ids = input_tokens.input_ids.to(DEVICE) + attention_mask = input_tokens.attention_mask.to(DEVICE) + + with torch.no_grad(): + generated_ids = models.llm.generate( + input_ids, + attention_mask=attention_mask, + max_new_tokens=100, + temperature=0.7, + top_p=0.9, + do_sample=True, + pad_token_id=models.tokenizer.eos_token_id + ) + + # Decode the response + response_text = models.tokenizer.decode( + generated_ids[0][input_ids.shape[1]:], + skip_special_tokens=True + ).strip() + except Exception as e: + logger.error(f"Error generating response: {str(e)}") + logger.error(traceback.format_exc()) + response_text = "I'm sorry, I encountered an error while processing your request." 
+ + # Synthesize speech + with app.app_context(): + socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id) + + # Start sending the audio response + socketio.emit('audio_response_start', { + 'text': response_text, + 'total_chunks': 1, + 'chunk_index': 0 + }, room=session_id) + + # Define AI speaker ID + ai_speaker_id = conversation.ai_speaker_id + + # Generate audio + audio_tensor = models.generator.generate( + text=response_text, + speaker=ai_speaker_id, + context=conversation.get_context(), + max_audio_length_ms=10_000, + temperature=0.9 + ) + + # Add AI response to conversation history + ai_segment = conversation.add_segment( + text=response_text, + speaker=ai_speaker_id, + audio=audio_tensor + ) + + # Convert audio to WAV format + with io.BytesIO() as wav_io: + torchaudio.save( + wav_io, + audio_tensor.unsqueeze(0).cpu(), + models.generator.sample_rate, + format="wav" + ) + wav_io.seek(0) + wav_data = wav_io.read() + + # Convert WAV data to base64 + audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}" + + # Send audio chunk to client + with app.app_context(): + socketio.emit('audio_response_chunk', { + 'chunk': audio_base64, + 'chunk_index': 0, + 'total_chunks': 1, + 'is_last': True + }, room=session_id) + + # Signal completion + socketio.emit('audio_response_complete', { + 'text': response_text + }, room=session_id) + + finally: + # Clean up temp file + if os.path.exists(temp_path): + os.unlink(temp_path) + + except Exception as e: + logger.error(f"Error processing audio: {str(e)}") + logger.error(traceback.format_exc()) + with app.app_context(): + socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id) + finally: + # Reset processing flag + conversation.is_processing = False \ No newline at end of file diff --git a/Backend/watermarking.py b/Backend/api/watermarking.py similarity index 100% rename from Backend/watermarking.py rename to Backend/api/watermarking.py diff --git a/Backend/index.html b/Backend/index.html deleted file mode 100644 index 9950a00..0000000 --- a/Backend/index.html +++ /dev/null @@ -1,419 +0,0 @@ - - - - - - CSM Voice Chat - - - - - - -
-[deleted: the full markup of the standalone "CSM Voice Chat" page — page title and tagline ("Talk naturally with the AI using your voice"); a Conversation panel with a connection status ("Disconnected"), a 0% load-progress indicator, and CSM / ASR / LLM model badges; a Controls panel ("Click the button below to start and stop recording.", with the audio visualizer placeholder "Start speaking to see audio visualization"); and a Settings panel with its option controls.]
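For reference, the conversational flow exposed by the new api/socket_handlers.py can be exercised with a small python-socketio client along these lines. This is only a sketch: the event names and payload shapes ('audio_data' with a base64 data URL plus a speaker id, 'transcription', 'audio_response_chunk', 'error') come from the patch, while the server URL, the input.wav/reply.wav filenames, and the handler names are assumptions.

    # Minimal client sketch for the Socket.IO API added in api/socket_handlers.py.
    # Assumes the server is running locally and input.wav is a short 16 kHz mono recording.
    import base64
    import socketio  # python-socketio client

    sio = socketio.Client()

    @sio.on('connection_status')
    def on_connection_status(data):
        print('connection:', data['status'])

    @sio.on('transcription')
    def on_transcription(data):
        print(f"[speaker {data['speaker']}] {data['text']}")

    @sio.on('audio_response_chunk')
    def on_audio_response_chunk(data):
        # Each chunk arrives as a data URL ("data:audio/wav;base64,...").
        wav_bytes = base64.b64decode(data['chunk'].split(',')[1])
        with open('reply.wav', 'wb') as f:
            f.write(wav_bytes)

    @sio.on('error')
    def on_error(data):
        print('server error:', data['message'])

    sio.connect('http://localhost:5000')

    # The server expects a base64 data URL and a speaker id (0 = user).
    with open('input.wav', 'rb') as f:
        payload = 'data:audio/wav;base64,' + base64.b64encode(f.read()).decode('utf-8')
    sio.emit('audio_data', {'audio': payload, 'speaker': 0})

    sio.wait()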
- - - - - - - \ No newline at end of file diff --git a/Backend/requirements.txt b/Backend/requirements.txt deleted file mode 100644 index 1e05eb3..0000000 --- a/Backend/requirements.txt +++ /dev/null @@ -1,13 +0,0 @@ -flask==2.2.5 -flask-socketio==5.3.6 -flask-cors==4.0.0 -torch==2.4.0 -torchaudio==2.4.0 -tokenizers==0.21.0 -transformers==4.49.0 -librosa==0.10.1 -huggingface_hub==0.28.1 -moshi==0.2.2 -torchtune==0.4.0 -torchao==0.9.0 -silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master \ No newline at end of file diff --git a/Backend/run_csm.py b/Backend/run_csm.py deleted file mode 100644 index 0062973..0000000 --- a/Backend/run_csm.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import torch -import torchaudio -from huggingface_hub import hf_hub_download -from generator import load_csm_1b, Segment -from dataclasses import dataclass - -# Disable Triton compilation -os.environ["NO_TORCH_COMPILE"] = "1" - -# Default prompts are available at https://hf.co/sesame/csm-1b -prompt_filepath_conversational_a = hf_hub_download( - repo_id="sesame/csm-1b", - filename="prompts/conversational_a.wav" -) -prompt_filepath_conversational_b = hf_hub_download( - repo_id="sesame/csm-1b", - filename="prompts/conversational_b.wav" -) - -SPEAKER_PROMPTS = { - "conversational_a": { - "text": ( - "like revising for an exam I'd have to try and like keep up the momentum because I'd " - "start really early I'd be like okay I'm gonna start revising now and then like " - "you're revising for ages and then I just like start losing steam I didn't do that " - "for the exam we had recently to be fair that was a more of a last minute scenario " - "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I " - "sort of start the day with this not like a panic but like a" - ), - "audio": prompt_filepath_conversational_a - }, - "conversational_b": { - "text": ( - "like a super Mario level. Like it's very like high detail. And like, once you get " - "into the park, it just like, everything looks like a computer game and they have all " - "these, like, you know, if, if there's like a, you know, like in a Mario game, they " - "will have like a question block. And if you like, you know, punch it, a coin will " - "come out. So like everyone, when they come into the park, they get like this little " - "bracelet and then you can go punching question blocks around." 
- ), - "audio": prompt_filepath_conversational_b - } -} - -def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor: - audio_tensor, sample_rate = torchaudio.load(audio_path) - audio_tensor = audio_tensor.squeeze(0) - # Resample is lazy so we can always call it - audio_tensor = torchaudio.functional.resample( - audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate - ) - return audio_tensor - -def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment: - audio_tensor = load_prompt_audio(audio_path, sample_rate) - return Segment(text=text, speaker=speaker, audio=audio_tensor) - -def main(): - # Select the best available device, skipping MPS due to float64 limitations - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - print(f"Using device: {device}") - - # Load model - generator = load_csm_1b(device) - - # Prepare prompts - prompt_a = prepare_prompt( - SPEAKER_PROMPTS["conversational_a"]["text"], - 0, - SPEAKER_PROMPTS["conversational_a"]["audio"], - generator.sample_rate - ) - - prompt_b = prepare_prompt( - SPEAKER_PROMPTS["conversational_b"]["text"], - 1, - SPEAKER_PROMPTS["conversational_b"]["audio"], - generator.sample_rate - ) - - # Generate conversation - conversation = [ - {"text": "Hey how are you doing?", "speaker_id": 0}, - {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1}, - {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0}, - {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1} - ] - - # Generate each utterance - generated_segments = [] - prompt_segments = [prompt_a, prompt_b] - - for utterance in conversation: - print(f"Generating: {utterance['text']}") - audio_tensor = generator.generate( - text=utterance['text'], - speaker=utterance['speaker_id'], - context=prompt_segments + generated_segments, - max_audio_length_ms=10_000, - ) - generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor)) - - # Concatenate all generations - all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0) - torchaudio.save( - "full_conversation.wav", - all_audio.unsqueeze(0).cpu(), - generator.sample_rate - ) - print("Successfully generated full_conversation.wav") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/Backend/server.py b/Backend/server.py index e912a9d..b8af6b7 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -1,639 +1,19 @@ -import os -import io -import base64 -import time -import json -import uuid -import logging -import threading -import queue -import tempfile -import gc -from typing import Dict, List, Optional, Tuple +""" +CSM Voice Chat Server +A voice chat application that uses CSM 1B for voice synthesis, +WhisperX for speech recognition, and Llama 3.2 for language generation. 
+""" -import torch -import torchaudio -import numpy as np -from flask import Flask, request, jsonify, send_from_directory -from flask_socketio import SocketIO, emit -from flask_cors import CORS -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +# Start the Flask application +from api.app import app, socketio -# Import WhisperX for better transcription -import whisperx - -from generator import load_csm_1b, Segment -from dataclasses import dataclass - -# Add these imports at the top -import psutil -import gc - -# Configure logging -logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -# Initialize Flask app -app = Flask(__name__, static_folder='.') -CORS(app) -socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120) - -# Configure device -if torch.cuda.is_available(): - DEVICE = "cuda" -elif torch.backends.mps.is_available(): - DEVICE = "mps" -else: - DEVICE = "cpu" - -logger.info(f"Using device: {DEVICE}") - -# Global variables -active_conversations = {} -user_queues = {} -processing_threads = {} - -# Load models -@dataclass -class AppModels: - generator = None - tokenizer = None - llm = None - whisperx_model = None - whisperx_align_model = None - whisperx_align_metadata = None - diarize_model = None - last_language = None - -# Initialize the models object -models = AppModels() - -def load_models(): - """Load all required models""" - global models - - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0}) - - # CSM 1B loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'}) - models.generator = load_csm_1b(device=DEVICE) - logger.info("CSM 1B model loaded successfully") - socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33}) - if DEVICE == "cuda": - torch.cuda.empty_cache() - except Exception as e: - import traceback - error_details = traceback.format_exc() - logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") - socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) - - # WhisperX loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'}) - # Use WhisperX for better transcription with timestamps - import whisperx - - # Use compute_type based on device - compute_type = "float16" if DEVICE == "cuda" else "float32" - - # Load the WhisperX model (smaller model for faster processing) - models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type) - - logger.info("WhisperX model loaded successfully") - socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66}) - if DEVICE == "cuda": - torch.cuda.empty_cache() - except Exception as e: - import traceback - error_details = traceback.format_exc() - logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}") - socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) - - # Llama loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'}) - models.llm = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-3.2-1B", - 
device_map=DEVICE, - torch_dtype=torch.bfloat16 - ) - models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") - - # Configure all special tokens - models.tokenizer.pad_token = models.tokenizer.eos_token - models.tokenizer.padding_side = "left" # For causal language modeling - - # Inform the model about the pad token - if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None: - models.llm.config.pad_token_id = models.tokenizer.pad_token_id - - logger.info("Llama 3.2 model loaded successfully") - socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'}) - except Exception as e: - logger.error(f"Error loading Llama 3.2 model: {str(e)}") - socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)}) - -# Load models in a background thread -threading.Thread(target=load_models, daemon=True).start() - -# Conversation data structure -class Conversation: - def __init__(self, session_id): - self.session_id = session_id - self.segments: List[Segment] = [] - self.current_speaker = 0 - self.ai_speaker_id = 1 # Add this property - self.last_activity = time.time() - self.is_processing = False - - def add_segment(self, text, speaker, audio): - segment = Segment(text=text, speaker=speaker, audio=audio) - self.segments.append(segment) - self.last_activity = time.time() - return segment - - def get_context(self, max_segments=10): - """Return the most recent segments for context""" - return self.segments[-max_segments:] if self.segments else [] - -# Routes -@app.route('/') -def index(): - return send_from_directory('.', 'index.html') - -@app.route('/voice-chat.js') -def voice_chat_js(): - return send_from_directory('.', 'voice-chat.js') - -@app.route('/api/health') -def health_check(): - return jsonify({ - "status": "ok", - "cuda_available": torch.cuda.is_available(), - "models_loaded": models.generator is not None and models.llm is not None - }) - -# Fix the system_status function: - -@app.route('/api/status') -def system_status(): - return jsonify({ - "status": "ok", - "cuda_available": torch.cuda.is_available(), - "device": DEVICE, - "models": { - "generator": models.generator is not None, - "asr": models.whisperx_model is not None, # Use the correct model name - "llm": models.llm is not None - } - }) - -# Add a new endpoint to check system resources -@app.route('/api/system_resources') -def system_resources(): - # Get CPU usage - cpu_percent = psutil.cpu_percent(interval=0.1) - - # Get memory usage - memory = psutil.virtual_memory() - memory_used_gb = memory.used / (1024 ** 3) - memory_total_gb = memory.total / (1024 ** 3) - memory_percent = memory.percent - - # Get GPU memory if available - gpu_memory = {} - if torch.cuda.is_available(): - for i in range(torch.cuda.device_count()): - gpu_memory[f"gpu_{i}"] = { - "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3), - "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3), - "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3) - } - - return jsonify({ - "cpu_percent": cpu_percent, - "memory": { - "used_gb": memory_used_gb, - "total_gb": memory_total_gb, - "percent": memory_percent - }, - "gpu_memory": gpu_memory, - "active_sessions": len(active_conversations) - }) - -# Socket event handlers -@socketio.on('connect') -def 
handle_connect(auth=None): - session_id = request.sid - logger.info(f"Client connected: {session_id}") - - # Initialize conversation data - if session_id not in active_conversations: - active_conversations[session_id] = Conversation(session_id) - user_queues[session_id] = queue.Queue() - processing_threads[session_id] = threading.Thread( - target=process_audio_queue, - args=(session_id, user_queues[session_id]), - daemon=True - ) - processing_threads[session_id].start() - - emit('connection_status', {'status': 'connected'}) - -@socketio.on('disconnect') -def handle_disconnect(reason=None): - session_id = request.sid - logger.info(f"Client disconnected: {session_id}. Reason: {reason}") - - # Cleanup - if session_id in active_conversations: - # Mark for deletion rather than immediately removing - # as the processing thread might still be accessing it - active_conversations[session_id].is_processing = False - user_queues[session_id].put(None) # Signal thread to terminate - -@socketio.on('start_stream') -def handle_start_stream(): - session_id = request.sid - logger.info(f"Starting stream for client: {session_id}") - emit('streaming_status', {'status': 'active'}) - -@socketio.on('stop_stream') -def handle_stop_stream(): - session_id = request.sid - logger.info(f"Stopping stream for client: {session_id}") - emit('streaming_status', {'status': 'inactive'}) - -@socketio.on('clear_context') -def handle_clear_context(): - session_id = request.sid - logger.info(f"Clearing context for client: {session_id}") - - if session_id in active_conversations: - active_conversations[session_id].segments = [] - emit('context_updated', {'status': 'cleared'}) - -@socketio.on('audio_chunk') -def handle_audio_chunk(data): - session_id = request.sid - audio_data = data.get('audio', '') - speaker_id = int(data.get('speaker', 0)) - - if not audio_data or not session_id in active_conversations: - return - - # Update the current speaker - active_conversations[session_id].current_speaker = speaker_id - - # Queue audio for processing - user_queues[session_id].put({ - 'audio': audio_data, - 'speaker': speaker_id - }) - - emit('processing_status', {'status': 'transcribing'}) - -def process_audio_queue(session_id, q): - """Background thread to process audio chunks for a session""" - logger.info(f"Started processing thread for session: {session_id}") - - try: - while session_id in active_conversations: - try: - # Get the next audio chunk with a timeout - data = q.get(timeout=120) - if data is None: # Termination signal - break - - # Process the audio and generate a response - process_audio_and_respond(session_id, data) - - except queue.Empty: - # Timeout, check if session is still valid - continue - except Exception as e: - logger.error(f"Error processing audio for {session_id}: {str(e)}") - # Create an app context for the socket emit - with app.app_context(): - socketio.emit('error', {'message': str(e)}, room=session_id) - finally: - logger.info(f"Ending processing thread for session: {session_id}") - # Clean up when thread is done - with app.app_context(): - if session_id in active_conversations: - del active_conversations[session_id] - if session_id in user_queues: - del user_queues[session_id] - -def process_audio_and_respond(session_id, data): - """Process audio data and generate a response using WhisperX""" - if models.generator is None or models.whisperx_model is None or models.llm is None: - logger.warning("Models not yet loaded!") - with app.app_context(): - socketio.emit('error', {'message': 'Models still loading, 
please wait'}, room=session_id) - return - - logger.info(f"Processing audio for session {session_id}") - conversation = active_conversations[session_id] - - try: - # Set processing flag - conversation.is_processing = True - - # Process base64 audio data - audio_data = data['audio'] - speaker_id = data['speaker'] - logger.info(f"Received audio from speaker {speaker_id}") - - # Convert from base64 to WAV - try: - audio_bytes = base64.b64decode(audio_data.split(',')[1]) - logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes") - except Exception as e: - logger.error(f"Error decoding base64 audio: {str(e)}") - raise - - # Save to temporary file for processing - with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: - temp_file.write(audio_bytes) - temp_path = temp_file.name - - try: - # Notify client that transcription is starting - with app.app_context(): - socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) - - # Load audio using WhisperX - import whisperx - audio = whisperx.load_audio(temp_path) - - # Check audio length and add a warning for short clips - audio_length = len(audio) / 16000 # assuming 16kHz sample rate - if audio_length < 1.0: - logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality") - - # Transcribe using WhisperX - batch_size = 16 # adjust based on your GPU memory - logger.info("Running WhisperX transcription...") - - # Handle the warning about audio being shorter than 30s by suppressing it - import warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="audio is shorter than 30s") - result = models.whisperx_model.transcribe(audio, batch_size=batch_size) - - # Get the detected language - language_code = result["language"] - logger.info(f"Detected language: {language_code}") - - # Check if alignment model needs to be loaded or updated - if models.whisperx_align_model is None or language_code != models.last_language: - # Clean up old models if they exist - if models.whisperx_align_model is not None: - del models.whisperx_align_model - del models.whisperx_align_metadata - if DEVICE == "cuda": - gc.collect() - torch.cuda.empty_cache() - - # Load new alignment model for the detected language - logger.info(f"Loading alignment model for language: {language_code}") - models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model( - language_code=language_code, device=DEVICE - ) - models.last_language = language_code - - # Align the transcript to get word-level timestamps - if result["segments"] and len(result["segments"]) > 0: - logger.info("Aligning transcript...") - result = whisperx.align( - result["segments"], - models.whisperx_align_model, - models.whisperx_align_metadata, - audio, - DEVICE, - return_char_alignments=False - ) - - # Process the segments for better output - for segment in result["segments"]: - # Round timestamps for better display - segment["start"] = round(segment["start"], 2) - segment["end"] = round(segment["end"], 2) - # Add a confidence score if not present - if "confidence" not in segment: - segment["confidence"] = 1.0 # Default confidence - - # Extract the full text from all segments - user_text = ' '.join([segment['text'] for segment in result['segments']]) - - # If no text was recognized, don't process further - if not user_text or len(user_text.strip()) == 0: - with app.app_context(): - socketio.emit('error', {'message': 'No speech detected'}, room=session_id) - return - - logger.info(f"Transcription: 
{user_text}") - - # Load audio for CSM input - waveform, sample_rate = torchaudio.load(temp_path) - - # Normalize to mono if needed - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) - - # Resample to the CSM sample rate if needed - if sample_rate != models.generator.sample_rate: - waveform = torchaudio.functional.resample( - waveform, - orig_freq=sample_rate, - new_freq=models.generator.sample_rate - ) - - # Add the user's message to conversation history - user_segment = conversation.add_segment( - text=user_text, - speaker=speaker_id, - audio=waveform.squeeze() - ) - - # Send transcription to client with detailed segments - with app.app_context(): - socketio.emit('transcription', { - 'text': user_text, - 'speaker': speaker_id, - 'segments': result['segments'] # Include the detailed segments with timestamps - }, room=session_id) - - # Generate AI response using Llama - with app.app_context(): - socketio.emit('processing_status', {'status': 'generating'}, room=session_id) - - # Create prompt from conversation history - conversation_history = "" - for segment in conversation.segments[-5:]: # Last 5 segments for context - role = "User" if segment.speaker == 0 else "Assistant" - conversation_history += f"{role}: {segment.text}\n" - - # Add final prompt - prompt = f"{conversation_history}Assistant: " - - # Generate response with Llama - try: - # Ensure pad token is set - if models.tokenizer.pad_token is None: - models.tokenizer.pad_token = models.tokenizer.eos_token - - input_tokens = models.tokenizer( - prompt, - return_tensors="pt", - padding=True, - return_attention_mask=True - ) - input_ids = input_tokens.input_ids.to(DEVICE) - attention_mask = input_tokens.attention_mask.to(DEVICE) - - with torch.no_grad(): - generated_ids = models.llm.generate( - input_ids, - attention_mask=attention_mask, - max_new_tokens=100, - temperature=0.7, - top_p=0.9, - do_sample=True, - pad_token_id=models.tokenizer.eos_token_id - ) - - # Decode the response - response_text = models.tokenizer.decode( - generated_ids[0][input_ids.shape[1]:], - skip_special_tokens=True - ).strip() - except Exception as e: - logger.error(f"Error generating response: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - response_text = "I'm sorry, I encountered an error while processing your request." 
- - # Synthesize speech - with app.app_context(): - socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id) - - # Start sending the audio response - socketio.emit('audio_response_start', { - 'text': response_text, - 'total_chunks': 1, - 'chunk_index': 0 - }, room=session_id) - - # Define AI speaker ID - ai_speaker_id = conversation.ai_speaker_id - - # Generate audio - audio_tensor = models.generator.generate( - text=response_text, - speaker=ai_speaker_id, - context=conversation.get_context(), - max_audio_length_ms=10_000, - temperature=0.9 - ) - - # Add AI response to conversation history - ai_segment = conversation.add_segment( - text=response_text, - speaker=ai_speaker_id, - audio=audio_tensor - ) - - # Convert audio to WAV format - with io.BytesIO() as wav_io: - torchaudio.save( - wav_io, - audio_tensor.unsqueeze(0).cpu(), - models.generator.sample_rate, - format="wav" - ) - wav_io.seek(0) - wav_data = wav_io.read() - - # Convert WAV data to base64 - audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}" - - # Send audio chunk to client - with app.app_context(): - socketio.emit('audio_response_chunk', { - 'chunk': audio_base64, - 'chunk_index': 0, - 'total_chunks': 1, - 'is_last': True - }, room=session_id) - - # Signal completion - socketio.emit('audio_response_complete', { - 'text': response_text - }, room=session_id) - - finally: - # Clean up temp file - if os.path.exists(temp_path): - os.unlink(temp_path) - - except Exception as e: - logger.error(f"Error processing audio: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - with app.app_context(): - socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id) - finally: - # Reset processing flag - conversation.is_processing = False - -# Error handler -@socketio.on_error() -def error_handler(e): - logger.error(f"SocketIO error: {str(e)}") - -# Periodic cleanup of inactive sessions -def cleanup_inactive_sessions(): - """Remove sessions that have been inactive for too long""" - current_time = time.time() - inactive_timeout = 3600 # 1 hour - - for session_id in list(active_conversations.keys()): - conversation = active_conversations[session_id] - if (current_time - conversation.last_activity > inactive_timeout and - not conversation.is_processing): - - logger.info(f"Cleaning up inactive session: {session_id}") - - # Signal processing thread to terminate - if session_id in user_queues: - user_queues[session_id].put(None) - - # Remove from active conversations - del active_conversations[session_id] - -# Start cleanup thread -def start_cleanup_thread(): - while True: - try: - cleanup_inactive_sessions() - except Exception as e: - logger.error(f"Error in cleanup thread: {str(e)}") - time.sleep(300) # Run every 5 minutes - -cleanup_thread = threading.Thread(target=start_cleanup_thread, daemon=True) -cleanup_thread.start() - -# Start the server if __name__ == '__main__': + import os + port = int(os.environ.get('PORT', 5000)) debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true' - logger.info(f"Starting server on port {port} (debug={debug_mode})") + + print(f"Starting server on port {port} (debug={debug_mode})") + print("Visit http://localhost:5000 to access the application") + socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/Backend/setup.py b/Backend/setup.py deleted file mode 100644 index 8eddb95..0000000 --- a/Backend/setup.py +++ /dev/null @@ -1,13 +0,0 @@ -from 
setuptools import setup, find_packages -import os - -# Read requirements from requirements.txt -with open('requirements.txt') as f: - requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')] - -setup( - name='csm', - version='0.1.0', - packages=find_packages(), - install_requires=requirements, -) diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js deleted file mode 100644 index dc2db04..0000000 --- a/Backend/voice-chat.js +++ /dev/null @@ -1,1054 +0,0 @@ -/** - * CSM AI Voice Chat Client - * - * A web client that connects to a CSM AI voice chat server and enables - * real-time voice conversation with an AI assistant. - */ - -// Configuration constants -const SERVER_URL = window.location.hostname === 'localhost' ? - 'http://localhost:5000' : window.location.origin; -const ENERGY_WINDOW_SIZE = 15; -const CLIENT_SILENCE_DURATION_MS = 750; - -// DOM elements -const elements = { - conversation: document.getElementById('conversation'), - streamButton: document.getElementById('streamButton'), - clearButton: document.getElementById('clearButton'), - thresholdSlider: document.getElementById('thresholdSlider'), - thresholdValue: document.getElementById('thresholdValue'), - visualizerCanvas: document.getElementById('audioVisualizer'), - visualizerLabel: document.getElementById('visualizerLabel'), - volumeLevel: document.getElementById('volumeLevel'), - statusDot: document.getElementById('statusDot'), - statusText: document.getElementById('statusText'), - speakerSelection: document.getElementById('speakerSelect'), - autoPlayResponses: document.getElementById('autoPlayResponses'), - showVisualizer: document.getElementById('showVisualizer') -}; - -// Application state -const state = { - socket: null, - audioContext: null, - analyser: null, - microphone: null, - streamProcessor: null, - isStreaming: false, - isSpeaking: false, - silenceThreshold: 0.01, - energyWindow: [], - silenceTimer: null, - volumeUpdateInterval: null, - visualizerAnimationFrame: null, - currentSpeaker: 0, - aiSpeakerId: 1, // Define the AI's speaker ID to match server.py - transcriptionRetries: 0, - maxTranscriptionRetries: 3 -}; - -// Visualizer variables -let canvasContext = null; -let visualizerBufferLength = 0; -let visualizerDataArray = null; - -// Audio streaming state -const streamingAudio = { - messageElement: null, - audioElement: null, - chunks: [], - totalChunks: 0, - receivedChunks: 0, - text: '', - complete: false -}; - -// Initialize the application -function initializeApp() { - // Initialize the UI elements - initializeUIElements(); - - // Initialize socket.io connection - setupSocketConnection(); - - // Setup event listeners - setupEventListeners(); - - // Initialize visualizer - setupVisualizer(); - - // Show welcome message - addSystemMessage('Welcome to CSM Voice Chat! 
Click "Start Conversation" to begin.'); -} - -// Initialize UI elements -function initializeUIElements() { - // Update threshold display - if (elements.thresholdValue) { - elements.thresholdValue.textContent = state.silenceThreshold.toFixed(3); - } -} - -// Setup Socket.IO connection -function setupSocketConnection() { - state.socket = io(SERVER_URL); - - // Connection events - state.socket.on('connect', () => { - updateConnectionStatus(true); - addSystemMessage('Connected to server.'); - }); - - state.socket.on('disconnect', () => { - updateConnectionStatus(false); - addSystemMessage('Disconnected from server.'); - stopStreaming(false); - }); - - state.socket.on('error', (data) => { - console.error('Server error:', data.message); - - // Make the error more user-friendly - let userMessage = data.message; - - // Check for common errors and provide more helpful messages - if (data.message.includes('Models still loading')) { - userMessage = 'The AI models are still loading. Please wait a moment and try again.'; - } else if (data.message.includes('No speech detected')) { - userMessage = 'No speech was detected. Please speak clearly and try again.'; - } - - addSystemMessage(`Error: ${userMessage}`); - - // Reset button state if it was processing - if (elements.streamButton.classList.contains('processing')) { - elements.streamButton.classList.remove('processing'); - elements.streamButton.innerHTML = ' Start Conversation'; - } - }); - - // Register message handlers - state.socket.on('transcription', handleTranscription); - state.socket.on('context_updated', handleContextUpdate); - state.socket.on('streaming_status', handleStreamingStatus); - state.socket.on('processing_status', handleProcessingStatus); - - // Add model status handlers - state.socket.on('model_status', handleModelStatusUpdate); - - // Handlers for incremental audio streaming - state.socket.on('audio_response_start', handleAudioResponseStart); - state.socket.on('audio_response_chunk', handleAudioResponseChunk); - state.socket.on('audio_response_complete', handleAudioResponseComplete); -} - -// Setup event listeners -function setupEventListeners() { - // Stream button - elements.streamButton.addEventListener('click', toggleStreaming); - - // Clear button - elements.clearButton.addEventListener('click', clearConversation); - - // Threshold slider - if (elements.thresholdSlider) { - elements.thresholdSlider.addEventListener('input', updateThreshold); - } - - // Speaker selection - elements.speakerSelection.addEventListener('change', () => { - state.currentSpeaker = parseInt(elements.speakerSelection.value); - }); - - // Visualizer toggle - if (elements.showVisualizer) { - elements.showVisualizer.addEventListener('change', toggleVisualizerVisibility); - } -} - -// Setup audio visualizer -function setupVisualizer() { - if (!elements.visualizerCanvas) return; - - canvasContext = elements.visualizerCanvas.getContext('2d'); - - // Set canvas dimensions - elements.visualizerCanvas.width = elements.visualizerCanvas.offsetWidth; - elements.visualizerCanvas.height = elements.visualizerCanvas.offsetHeight; - - // Initialize visualization data array - visualizerDataArray = new Uint8Array(128); - - // Start the visualizer animation - drawVisualizer(); -} - -// Update connection status UI -function updateConnectionStatus(isConnected) { - if (isConnected) { - elements.statusDot.classList.add('active'); - elements.statusText.textContent = 'Connected'; - } else { - elements.statusDot.classList.remove('active'); - elements.statusText.textContent = 
'Disconnected'; - } -} - -// Toggle streaming state -function toggleStreaming() { - if (state.isStreaming) { - stopStreaming(); - } else { - startStreaming(); - } -} - -// Start streaming audio to the server -function startStreaming() { - if (!state.socket || !state.socket.connected) { - addSystemMessage('Not connected to server. Please refresh the page.'); - return; - } - - // Check if models are loaded via the API - fetch('/api/status') - .then(response => response.json()) - .then(data => { - if (!data.models.generator || !data.models.asr || !data.models.llm) { - addSystemMessage('Still loading AI models. Please wait...'); - return; - } - - // Continue with recording if models are loaded - initializeRecording(); - }) - .catch(error => { - console.error('Error checking model status:', error); - // Try anyway, the server will respond with an error if models aren't ready - initializeRecording(); - }); -} - -// Extracted the recording initialization to a separate function -function initializeRecording() { - // Request microphone access - navigator.mediaDevices.getUserMedia({ audio: true, video: false }) - .then(stream => { - state.isStreaming = true; - elements.streamButton.classList.add('recording'); - elements.streamButton.innerHTML = ' Stop Recording'; - - // Initialize Web Audio API - state.audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 }); - state.microphone = state.audioContext.createMediaStreamSource(stream); - state.analyser = state.audioContext.createAnalyser(); - state.analyser.fftSize = 2048; - - // Setup analyzer for visualizer - visualizerBufferLength = state.analyser.frequencyBinCount; - visualizerDataArray = new Uint8Array(visualizerBufferLength); - - state.microphone.connect(state.analyser); - - // Create processor node for audio data - const processorNode = state.audioContext.createScriptProcessor(4096, 1, 1); - processorNode.onaudioprocess = handleAudioProcess; - state.analyser.connect(processorNode); - processorNode.connect(state.audioContext.destination); - state.streamProcessor = processorNode; - - state.silenceTimer = null; - state.energyWindow = []; - state.isSpeaking = false; - - // Notify server - state.socket.emit('start_stream'); - - // Start volume meter updates - state.volumeUpdateInterval = setInterval(updateVolumeMeter, 100); - - // Make sure visualizer is visible if enabled - if (elements.showVisualizer && elements.showVisualizer.checked) { - elements.visualizerLabel.style.opacity = '0'; - } - - addSystemMessage('Recording started. Speak now...'); - }) - .catch(error => { - console.error('Error accessing microphone:', error); - addSystemMessage('Could not access microphone. 
Please check permissions.'); - }); -} - -// Stop streaming audio -function stopStreaming(notifyServer = true) { - if (state.isStreaming) { - state.isStreaming = false; - elements.streamButton.classList.remove('recording'); - elements.streamButton.classList.remove('processing'); - elements.streamButton.innerHTML = ' Start Conversation'; - - // Clean up audio resources - if (state.streamProcessor) { - state.streamProcessor.disconnect(); - state.streamProcessor = null; - } - - if (state.analyser) { - state.analyser.disconnect(); - state.analyser = null; - } - - if (state.microphone) { - state.microphone.disconnect(); - state.microphone = null; - } - - if (state.audioContext) { - state.audioContext.close().catch(err => console.warn('Error closing audio context:', err)); - state.audioContext = null; - } - - // Clear any pending silence timer - if (state.silenceTimer) { - clearTimeout(state.silenceTimer); - state.silenceTimer = null; - } - - // Clear volume meter updates - if (state.volumeUpdateInterval) { - clearInterval(state.volumeUpdateInterval); - state.volumeUpdateInterval = null; - - // Reset volume meter - if (elements.volumeLevel) { - elements.volumeLevel.style.width = '0%'; - } - } - - // Show visualizer label - if (elements.visualizerLabel) { - elements.visualizerLabel.style.opacity = '0.7'; - } - - // Notify server if needed - if (notifyServer && state.socket && state.socket.connected) { - state.socket.emit('stop_stream'); - } - - addSystemMessage('Recording stopped.'); - } -} - -// Handle audio processing -function handleAudioProcess(event) { - if (!state.isStreaming) return; - - const inputData = event.inputBuffer.getChannelData(0); - const energy = calculateAudioEnergy(inputData); - updateEnergyWindow(energy); - - const averageEnergy = calculateAverageEnergy(); - const isSilent = averageEnergy < state.silenceThreshold; - - handleSpeechState(isSilent); -} - -// Calculate audio energy (volume) -function calculateAudioEnergy(buffer) { - let sum = 0; - for (let i = 0; i < buffer.length; i++) { - sum += buffer[i] * buffer[i]; - } - return Math.sqrt(sum / buffer.length); -} - -// Update energy window for averaging -function updateEnergyWindow(energy) { - state.energyWindow.push(energy); - if (state.energyWindow.length > ENERGY_WINDOW_SIZE) { - state.energyWindow.shift(); - } -} - -// Calculate average energy from window -function calculateAverageEnergy() { - if (state.energyWindow.length === 0) return 0; - - const sum = state.energyWindow.reduce((acc, val) => acc + val, 0); - return sum / state.energyWindow.length; -} - -// Update the threshold from the slider -function updateThreshold() { - state.silenceThreshold = parseFloat(elements.thresholdSlider.value); - elements.thresholdValue.textContent = state.silenceThreshold.toFixed(3); -} - -// Update the volume meter display -function updateVolumeMeter() { - if (!state.isStreaming || !state.energyWindow.length || !elements.volumeLevel) return; - - const avgEnergy = calculateAverageEnergy(); - - // Scale energy to percentage (0-100) - // Energy values are typically very small (e.g., 0.001 to 0.1) - const scaleFactor = 1000; - const percentage = Math.min(100, Math.max(0, avgEnergy * scaleFactor)); - - // Update volume meter width - elements.volumeLevel.style.width = `${percentage}%`; - - // Change color based on level - if (percentage > 70) { - elements.volumeLevel.style.backgroundColor = '#ff5252'; - } else if (percentage > 30) { - elements.volumeLevel.style.backgroundColor = '#4CAF50'; - } else { - 
elements.volumeLevel.style.backgroundColor = '#4c84ff'; - } -} - -// Handle speech/silence state transitions -function handleSpeechState(isSilent) { - if (state.isSpeaking) { - if (isSilent) { - // User was speaking but now is silent - if (!state.silenceTimer) { - state.silenceTimer = setTimeout(() => { - // Silence lasted long enough, consider speech done - if (state.isSpeaking) { - state.isSpeaking = false; - - try { - // Get the current audio data and send it - const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max - state.analyser.getFloatTimeDomainData(audioBuffer); - - // Check if audio has content - const hasAudioContent = audioBuffer.some(sample => Math.abs(sample) > 0.01); - - if (!hasAudioContent) { - console.warn('Audio buffer appears to be empty or very quiet'); - - if (state.transcriptionRetries < state.maxTranscriptionRetries) { - state.transcriptionRetries++; - const retryMessage = `No speech detected (attempt ${state.transcriptionRetries}/${state.maxTranscriptionRetries}). Please speak louder and try again.`; - addSystemMessage(retryMessage); - } else { - state.transcriptionRetries = 0; - addSystemMessage('Multiple attempts failed to detect speech. Please check your microphone and try again.'); - } - return; - } - - // Create WAV blob - const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate); - - // Convert to base64 - const reader = new FileReader(); - reader.onloadend = function() { - sendAudioChunk(reader.result, state.currentSpeaker); - }; - reader.readAsDataURL(wavBlob); - - // Update button state - elements.streamButton.classList.add('processing'); - elements.streamButton.innerHTML = ' Processing...'; - - addSystemMessage('Processing your message...'); - } catch (e) { - console.error('Error recording audio:', e); - addSystemMessage('Error recording audio. 
Please try again.'); - } - } - }, CLIENT_SILENCE_DURATION_MS); - } - } else { - // User is still speaking, reset silence timer - if (state.silenceTimer) { - clearTimeout(state.silenceTimer); - state.silenceTimer = null; - } - } - } else { - if (!isSilent) { - // User started speaking - state.isSpeaking = true; - if (state.silenceTimer) { - clearTimeout(state.silenceTimer); - state.silenceTimer = null; - } - } - } -} - -// Send audio chunk to server -function sendAudioChunk(audioData, speaker) { - if (state.socket && state.socket.connected) { - state.socket.emit('audio_chunk', { - audio: audioData, - speaker: speaker - }); - } -} - -// Create WAV blob from audio data -function createWavBlob(audioData, sampleRate) { - const numChannels = 1; - const bitsPerSample = 16; - const bytesPerSample = bitsPerSample / 8; - - // Create buffer for WAV file - const buffer = new ArrayBuffer(44 + audioData.length * bytesPerSample); - const view = new DataView(buffer); - - // Write WAV header - // "RIFF" chunk descriptor - writeString(view, 0, 'RIFF'); - view.setUint32(4, 36 + audioData.length * bytesPerSample, true); - writeString(view, 8, 'WAVE'); - - // "fmt " sub-chunk - writeString(view, 12, 'fmt '); - view.setUint32(16, 16, true); // subchunk1size - view.setUint16(20, 1, true); // audio format (PCM) - view.setUint16(22, numChannels, true); - view.setUint32(24, sampleRate, true); - view.setUint32(28, sampleRate * numChannels * bytesPerSample, true); // byte rate - view.setUint16(32, numChannels * bytesPerSample, true); // block align - view.setUint16(34, bitsPerSample, true); - - // "data" sub-chunk - writeString(view, 36, 'data'); - view.setUint32(40, audioData.length * bytesPerSample, true); - - // Write audio data - const audioDataStart = 44; - for (let i = 0; i < audioData.length; i++) { - const sample = Math.max(-1, Math.min(1, audioData[i])); - const value = sample < 0 ? 
sample * 0x8000 : sample * 0x7FFF; - view.setInt16(audioDataStart + i * bytesPerSample, value, true); - } - - return new Blob([buffer], { type: 'audio/wav' }); -} - -// Helper function to write strings to DataView -function writeString(view, offset, string) { - for (let i = 0; i < string.length; i++) { - view.setUint8(offset + i, string.charCodeAt(i)); - } -} - -// Clear conversation history -function clearConversation() { - elements.conversation.innerHTML = ''; - if (state.socket && state.socket.connected) { - state.socket.emit('clear_context'); - } - addSystemMessage('Conversation cleared.'); -} - -// Draw audio visualizer -function drawVisualizer() { - if (!canvasContext || !elements.visualizerCanvas) { - state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer); - return; - } - - state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer); - - // Skip drawing if visualizer is hidden or not enabled - if (elements.showVisualizer && !elements.showVisualizer.checked) { - if (elements.visualizerCanvas.style.opacity !== '0') { - elements.visualizerCanvas.style.opacity = '0'; - } - return; - } else if (elements.visualizerCanvas.style.opacity !== '1') { - elements.visualizerCanvas.style.opacity = '1'; - } - - // Get frequency data if available - if (state.isStreaming && state.analyser) { - try { - state.analyser.getByteFrequencyData(visualizerDataArray); - } catch (e) { - console.warn('Error getting frequency data:', e); - } - } else { - // Fade out when not streaming - for (let i = 0; i < visualizerDataArray.length; i++) { - visualizerDataArray[i] = Math.max(0, visualizerDataArray[i] - 5); - } - } - - // Clear canvas - canvasContext.fillStyle = 'rgb(0, 0, 0)'; - canvasContext.fillRect(0, 0, elements.visualizerCanvas.width, elements.visualizerCanvas.height); - - // Draw gradient bars - const width = elements.visualizerCanvas.width; - const height = elements.visualizerCanvas.height; - const barCount = Math.min(visualizerBufferLength, 64); - const barWidth = width / barCount - 1; - - for (let i = 0; i < barCount; i++) { - const index = Math.floor(i * visualizerBufferLength / barCount); - const value = visualizerDataArray[index]; - - // Use logarithmic scale for better audio visualization - const logFactor = 20; - const scaledValue = Math.log(1 + (value / 255) * logFactor) / Math.log(1 + logFactor); - const barHeight = scaledValue * height; - - // Position bars - const x = i * (barWidth + 1); - const y = height - barHeight; - - // Create color gradient based on frequency and amplitude - const hue = i / barCount * 360; // Full color spectrum - const saturation = 80 + (value / 255 * 20); // Higher values more saturated - const lightness = 40 + (value / 255 * 20); // Dynamic brightness - - // Draw main bar - canvasContext.fillStyle = `hsl(${hue}, ${saturation}%, ${lightness}%)`; - canvasContext.fillRect(x, y, barWidth, barHeight); - - // Add highlight effect - if (barHeight > 5) { - const gradient = canvasContext.createLinearGradient( - x, y, - x, y + barHeight * 0.5 - ); - gradient.addColorStop(0, `hsla(${hue}, ${saturation}%, ${lightness + 20}%, 0.4)`); - gradient.addColorStop(1, `hsla(${hue}, ${saturation}%, ${lightness}%, 0)`); - canvasContext.fillStyle = gradient; - canvasContext.fillRect(x, y, barWidth, barHeight * 0.5); - - // Add highlight on top of the bar - canvasContext.fillStyle = `hsla(${hue}, ${saturation - 20}%, ${lightness + 30}%, 0.7)`; - canvasContext.fillRect(x, y, barWidth, 2); - } - } -} - -// Toggle visualizer visibility -function 
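
> Editor's note: `createWavBlob` above writes a 44-byte RIFF/WAVE header followed by 16-bit PCM, and `sendAudioChunk` ships the result to the server as a base64 data URL on the `audio_chunk` socket event. A condensed sketch of that encoding and hand-off is below; the `socket` object is assumed to be a connected socket.io-client instance and is not shown.

```typescript
// Encode mono Float32 samples as a 16-bit PCM WAV blob (44-byte RIFF header),
// then emit it as a base64 data URL on "audio_chunk", as the removed client did.
function encodeWav(samples: Float32Array, sampleRate: number): Blob {
  const bytesPerSample = 2;
  const buffer = new ArrayBuffer(44 + samples.length * bytesPerSample);
  const view = new DataView(buffer);
  const writeStr = (offset: number, s: string) => {
    for (let i = 0; i < s.length; i++) view.setUint8(offset + i, s.charCodeAt(i));
  };

  writeStr(0, "RIFF");
  view.setUint32(4, 36 + samples.length * bytesPerSample, true); // RIFF chunk size
  writeStr(8, "WAVE");
  writeStr(12, "fmt ");
  view.setUint32(16, 16, true);                          // fmt chunk size
  view.setUint16(20, 1, true);                           // PCM format
  view.setUint16(22, 1, true);                           // mono
  view.setUint32(24, sampleRate, true);
  view.setUint32(28, sampleRate * bytesPerSample, true); // byte rate (mono)
  view.setUint16(32, bytesPerSample, true);              // block align
  view.setUint16(34, 16, true);                          // bits per sample
  writeStr(36, "data");
  view.setUint32(40, samples.length * bytesPerSample, true);

  for (let i = 0; i < samples.length; i++) {
    const s = Math.max(-1, Math.min(1, samples[i]));
    view.setInt16(44 + i * bytesPerSample, s < 0 ? s * 0x8000 : s * 0x7fff, true);
  }
  return new Blob([buffer], { type: "audio/wav" });
}

// Usage sketch: base64-encode via FileReader and emit with a speaker id.
function sendChunk(
  socket: { emit(event: string, data: unknown): void },
  samples: Float32Array,
  sampleRate: number,
  speaker = 0,
): void {
  const reader = new FileReader();
  reader.onloadend = () => socket.emit("audio_chunk", { audio: reader.result, speaker });
  reader.readAsDataURL(encodeWav(samples, sampleRate));
}
```
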
toggleVisualizerVisibility() { - const isVisible = elements.showVisualizer.checked; - elements.visualizerCanvas.style.opacity = isVisible ? '1' : '0'; -} - -// Add a message to the conversation -function addMessage(text, type) { - if (!elements.conversation) return; - - const messageDiv = document.createElement('div'); - messageDiv.className = `message ${type}`; - - const textElement = document.createElement('p'); - textElement.textContent = text; - messageDiv.appendChild(textElement); - - // Add timestamp to every message - const timestamp = new Date().toLocaleTimeString(); - const timeLabel = document.createElement('div'); - timeLabel.className = 'message-timestamp'; - timeLabel.textContent = timestamp; - messageDiv.appendChild(timeLabel); - - elements.conversation.appendChild(messageDiv); - - // Auto-scroll to the bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; - - return messageDiv; -} - -// Add a system message to the conversation -function addSystemMessage(text) { - if (!elements.conversation) return; - - const messageDiv = document.createElement('div'); - messageDiv.className = 'message system'; - messageDiv.textContent = text; - - elements.conversation.appendChild(messageDiv); - - // Auto-scroll to the bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; - - return messageDiv; -} - -// Handle transcription response from server -function handleTranscription(data) { - const speaker = data.speaker === 0 ? 'user' : 'ai'; - - // Create the message div - const messageDiv = addMessage(data.text, speaker); - - // If we have detailed segments from WhisperX, add timestamps - if (data.segments && data.segments.length > 0) { - // Add a timestamps container - const timestampsContainer = document.createElement('div'); - timestampsContainer.className = 'timestamps-container'; - timestampsContainer.style.display = 'none'; // Hidden by default - - // Add a toggle button - const toggleButton = document.createElement('button'); - toggleButton.className = 'timestamp-toggle'; - toggleButton.textContent = 'Show Timestamps'; - toggleButton.onclick = function() { - const isHidden = timestampsContainer.style.display === 'none'; - timestampsContainer.style.display = isHidden ? 'block' : 'none'; - toggleButton.textContent = isHidden ? 
'Hide Timestamps' : 'Show Timestamps'; - }; - - // Add timestamps for each segment - data.segments.forEach(segment => { - const timestampDiv = document.createElement('div'); - timestampDiv.className = 'timestamp'; - - // Format start and end times - const startTime = formatTime(segment.start); - const endTime = formatTime(segment.end); - - timestampDiv.innerHTML = ` - [${startTime} - ${endTime}] - ${segment.text} - `; - - timestampsContainer.appendChild(timestampDiv); - }); - - // Add the timestamp elements to the message - messageDiv.appendChild(toggleButton); - messageDiv.appendChild(timestampsContainer); - } else { - // No timestamp data available - add a simple timestamp for the entire message - const timestamp = new Date().toLocaleTimeString(); - const timeLabel = document.createElement('div'); - timeLabel.className = 'simple-timestamp'; - timeLabel.textContent = timestamp; - messageDiv.appendChild(timeLabel); - } - - return messageDiv; -} - -// Helper function to format time in seconds to MM:SS.ms format -function formatTime(seconds) { - const mins = Math.floor(seconds / 60); - const secs = Math.floor(seconds % 60); - const ms = Math.floor((seconds % 1) * 1000); - - return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}.${ms.toString().padStart(3, '0')}`; -} - -// Handle context update from server -function handleContextUpdate(data) { - if (data.status === 'cleared') { - elements.conversation.innerHTML = ''; - addSystemMessage('Conversation context cleared.'); - } -} - -// Handle streaming status updates from server -function handleStreamingStatus(data) { - if (data.status === 'active') { - console.log('Server acknowledged streaming is active'); - } else if (data.status === 'inactive') { - console.log('Server acknowledged streaming is inactive'); - } -} - -// Handle processing status updates -function handleProcessingStatus(data) { - switch (data.status) { - case 'transcribing': - addSystemMessage('Transcribing your message...'); - break; - case 'generating': - addSystemMessage('Generating response...'); - break; - case 'synthesizing': - addSystemMessage('Synthesizing voice...'); - break; - } -} - -// Handle the start of an audio streaming response -function handleAudioResponseStart(data) { - console.log(`Expecting ${data.total_chunks} audio chunks`); - - // Reset streaming state - streamingAudio.chunks = []; - streamingAudio.totalChunks = data.total_chunks; - streamingAudio.receivedChunks = 0; - streamingAudio.text = data.text; - streamingAudio.complete = false; -} - -// Handle an incoming audio chunk -function handleAudioResponseChunk(data) { - // Create or update audio element for playback - const audioElement = document.createElement('audio'); - if (elements.autoPlayResponses.checked) { - audioElement.autoplay = true; - } - audioElement.controls = true; - audioElement.className = 'audio-player'; - audioElement.src = data.chunk; - - // Store the chunk - streamingAudio.chunks[data.chunk_index] = data.chunk; - streamingAudio.receivedChunks++; - - // Store audio element reference for later use - streamingAudio.audioElement = audioElement; - - // Add to the conversation - const messages = elements.conversation.querySelectorAll('.message.ai'); - if (messages.length > 0) { - const lastAiMessage = messages[messages.length - 1]; - streamingAudio.messageElement = lastAiMessage; - - // Replace existing audio player if there is one - const existingPlayer = lastAiMessage.querySelector('.audio-player'); - if (existingPlayer) { - lastAiMessage.replaceChild(audioElement, 
existingPlayer); - } else { - lastAiMessage.appendChild(audioElement); - } - } else { - // Create a new message for the AI response - const aiMessage = document.createElement('div'); - aiMessage.className = 'message ai'; - streamingAudio.messageElement = aiMessage; - - if (streamingAudio.text) { - const textElement = document.createElement('p'); - textElement.textContent = streamingAudio.text; - aiMessage.appendChild(textElement); - } - - aiMessage.appendChild(audioElement); - elements.conversation.appendChild(aiMessage); - } - - // Auto-scroll - elements.conversation.scrollTop = elements.conversation.scrollHeight; - - // If this is the last chunk or we've received all expected chunks - if (data.is_last || streamingAudio.receivedChunks >= streamingAudio.totalChunks) { - streamingAudio.complete = true; - - // Reset stream button if we're still streaming - if (state.isStreaming) { - elements.streamButton.classList.remove('processing'); - elements.streamButton.innerHTML = ' Listening...'; - } - } -} - -// Handle completion of audio streaming -function handleAudioResponseComplete(data) { - console.log('Audio response complete:', data); - streamingAudio.complete = true; - - // Make sure we finalize the audio even if some chunks were missed - finalizeStreamingAudio(); - - // Update UI to normal state - if (state.isStreaming) { - elements.streamButton.innerHTML = ' Listening...'; - elements.streamButton.classList.add('recording'); - elements.streamButton.classList.remove('processing'); - } -} - -// Finalize streaming audio by combining chunks and updating the UI -function finalizeStreamingAudio() { - if (!streamingAudio.messageElement || streamingAudio.chunks.length === 0) { - return; - } - - try { - // For more sophisticated audio streaming, you would need to properly concatenate - // the WAV files, but for now we'll use the last chunk as the complete audio - // since it should contain the entire response due to how the server is implementing it - const lastChunkIndex = streamingAudio.chunks.length - 1; - const audioData = streamingAudio.chunks[lastChunkIndex] || streamingAudio.chunks[0]; - - // Update the audio element with the complete audio - if (streamingAudio.audioElement) { - streamingAudio.audioElement.src = audioData; - - // Auto-play if enabled and not already playing - if (elements.autoPlayResponses && elements.autoPlayResponses.checked && - streamingAudio.audioElement.paused) { - streamingAudio.audioElement.play() - .catch(err => { - console.warn('Auto-play failed:', err); - addSystemMessage('Auto-play failed. 
Please click play to hear the response.'); - }); - } - } - - // Remove loading indicator and processing class - if (streamingAudio.messageElement) { - const loadingElement = streamingAudio.messageElement.querySelector('.loading-indicator'); - if (loadingElement) { - streamingAudio.messageElement.removeChild(loadingElement); - } - streamingAudio.messageElement.classList.remove('processing'); - } - - console.log('Audio response finalized and ready for playback'); - } catch (e) { - console.error('Error finalizing streaming audio:', e); - } - - // Reset streaming audio state - streamingAudio.chunks = []; - streamingAudio.totalChunks = 0; - streamingAudio.receivedChunks = 0; - streamingAudio.messageElement = null; - streamingAudio.audioElement = null; -} - -// Enhance the handleModelStatusUpdate function: - -function handleModelStatusUpdate(data) { - const { model, status, message, progress } = data; - - if (model === 'overall' && status === 'loading') { - // Update overall loading progress - const progressBar = document.getElementById('modelLoadingProgress'); - if (progressBar) { - progressBar.value = progress; - progressBar.textContent = `${progress}%`; - } - return; - } - - if (status === 'loaded') { - console.log(`Model ${model} loaded successfully`); - addSystemMessage(`${model.toUpperCase()} model loaded successfully`); - - // Update UI to show model is ready - const modelStatusElement = document.getElementById(`${model}Status`); - if (modelStatusElement) { - modelStatusElement.classList.remove('loading'); - modelStatusElement.classList.add('loaded'); - modelStatusElement.title = 'Model loaded successfully'; - } - - // Check if the required models are loaded to enable conversation - checkAllModelsLoaded(); - } else if (status === 'error') { - console.error(`Error loading ${model} model: ${message}`); - addSystemMessage(`Error loading ${model.toUpperCase()} model: ${message}`); - - // Update UI to show model loading failed - const modelStatusElement = document.getElementById(`${model}Status`); - if (modelStatusElement) { - modelStatusElement.classList.remove('loading'); - modelStatusElement.classList.add('error'); - modelStatusElement.title = `Error: ${message}`; - } - } -} - -// Check if all required models are loaded and enable UI accordingly -function checkAllModelsLoaded() { - // When all models are loaded, enable the stream button if it was disabled - const allLoaded = - document.getElementById('csmStatus')?.classList.contains('loaded') && - document.getElementById('asrStatus')?.classList.contains('loaded') && - document.getElementById('llmStatus')?.classList.contains('loaded'); - - if (allLoaded) { - elements.streamButton.disabled = false; - addSystemMessage('All models loaded. 
Ready for conversation!'); - } -} - -// Add CSS styles for new UI elements -document.addEventListener('DOMContentLoaded', function() { - // Add styles for processing state and timestamps - const style = document.createElement('style'); - style.textContent = ` - .message.processing { - opacity: 0.8; - } - - .loading-indicator { - display: flex; - align-items: center; - margin-top: 8px; - font-size: 0.9em; - color: #666; - } - - .loading-spinner { - width: 16px; - height: 16px; - border: 2px solid #ddd; - border-top: 2px solid var(--primary-color); - border-radius: 50%; - margin-right: 8px; - animation: spin 1s linear infinite; - } - - @keyframes spin { - 0% { transform: rotate(0deg); } - 100% { transform: rotate(360deg); } - } - - /* Timestamp styles */ - .timestamp-toggle { - font-size: 0.75em; - padding: 4px 8px; - margin-top: 8px; - background-color: #f0f0f0; - border: 1px solid #ddd; - border-radius: 4px; - cursor: pointer; - } - - .timestamp-toggle:hover { - background-color: #e0e0e0; - } - - .timestamps-container { - margin-top: 8px; - padding: 8px; - background-color: #f9f9f9; - border-radius: 4px; - font-size: 0.85em; - } - - .timestamp { - margin-bottom: 4px; - padding: 2px 0; - } - - .timestamp .time { - color: #666; - font-family: monospace; - margin-right: 8px; - } - - .timestamp .text { - color: #333; - } - `; - document.head.appendChild(style); -}); - -// Initialize the application when DOM is fully loaded -document.addEventListener('DOMContentLoaded', initializeApp); - diff --git a/React/src/app/auth/session/route.ts b/React/src/app/auth/session/route.ts new file mode 100644 index 0000000..9299d4a --- /dev/null +++ b/React/src/app/auth/session/route.ts @@ -0,0 +1,12 @@ +import { NextResponse } from "next/server"; +import { auth0 } from "../../../lib/auth0"; + +export async function GET() { + try { + const session = await auth0.getSession(); + return NextResponse.json({ session }); + } catch (error) { + console.error("Error getting session:", error); + return NextResponse.json({ session: null }, { status: 500 }); + } +} diff --git a/React/src/app/page.tsx b/React/src/app/page.tsx index c91ed2b..21e0862 100644 --- a/React/src/app/page.tsx +++ b/React/src/app/page.tsx @@ -1,67 +1,94 @@ -import { useState } from "react"; -import { auth0 } from "../lib/auth0"; +"use client"; +import { useState, useEffect } from "react"; +import { useRouter } from "next/navigation"; - -export default async function Home() { - +export default function Home() { const [contacts, setContacts] = useState([]); const [codeword, setCodeword] = useState(""); + const [session, setSession] = useState(null); + const [loading, setLoading] = useState(true); + const router = useRouter(); - const session = await auth0.getSession(); - - console.log("Session:", session?.user); + useEffect(() => { + // Fetch session data from an API route + fetch("/auth/session") + .then((response) => response.json()) + .then((data) => { + setSession(data.session); + setLoading(false); + }) + .catch((error) => { + console.error("Failed to fetch session:", error); + setLoading(false); + }); + }, []); function saveToDB() { - //e.preventDefault(); alert("Saving contacts..."); - // const contactInputs = document.querySelectorAll(".text-input") as NodeListOf; - // const contactValues = Array.from(contactInputs).map(input => input.value); - // console.log("Contact values:", contactValues); - // // save codeword and contacts to database - // fetch("/api/databaseStorage", { - // method: "POST", - // headers: { - // "Content-Type": 
"application/json", - // }, - // body: JSON.stringify({ - // email: session?.user?.email || "", - // codeword: (document.getElementById("codeword") as HTMLInputElement)?.value, - // contacts: contactValues, - // }), - // }) - // .then((response) => { - // if (response.ok) { - // alert("Contacts saved successfully!"); - // } else { - // alert("Error saving contacts."); - // } - // }) - // .catch((error) => { - // console.error("Error:", error); - // alert("Error saving contacts."); - // }); - + const contactInputs = document.querySelectorAll( + ".text-input" + ) as NodeListOf; + const contactValues = Array.from(contactInputs).map((input) => input.value); + + fetch("/api/databaseStorage", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + email: session?.user?.email || "", + codeword: codeword, + contacts: contactValues, + }), + }) + .then((response) => { + if (response.ok) { + alert("Contacts saved successfully!"); + } else { + alert("Error saving contacts."); + } + }) + .catch((error) => { + console.error("Error:", error); + alert("Error saving contacts."); + }); } + if (loading) { + return
Loading...
; + } // If no session, show sign-up and login buttons - if (!session) { - + if (!session) { return (
- + - +
-

Fauxcall

-

Set emergency contacts

-

If you stop speaking or say the codeword, these contacts will be notified

+

+ Fauxcall +

+

+ Set emergency contacts +

+

+ If you stop speaking or say the codeword, these contacts will be + notified +

{/* form for setting codeword */} -
e.preventDefault()}> + e.preventDefault()} + > + className="bg-blue-500 text-white font-semibold font-lg rounded-md p-2" + type="submit" + > + Set codeword +
{/* form for adding contacts */} -
e.preventDefault()}> + e.preventDefault()} + > - +
     );
@@ -107,25 +145,42 @@ export default async function Home() {
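
> Editor's note: earlier in this file's diff, `page.tsx` stops calling `auth0.getSession()` on the server and instead fetches `/auth/session` from a client component inside `useEffect`. A typed sketch of that pattern as a reusable hook follows; the `SessionData` shape is an assumption based only on the fields the page reads (`user.name`, `user.email`), and the file would need the `"use client"` directive when used in the App Router.

```typescript
// Hypothetical hook wrapping the /auth/session fetch used by page.tsx.
import { useEffect, useState } from "react";

interface SessionUser {
  name?: string;
  email?: string;
}

interface SessionData {
  user?: SessionUser;
}

export function useSession(): { session: SessionData | null; loading: boolean } {
  const [session, setSession] = useState<SessionData | null>(null);
  const [loading, setLoading] = useState(true);

  useEffect(() => {
    // The route handler returns { session } and { session: null } on error.
    fetch("/auth/session")
      .then((res) => res.json())
      .then((data: { session: SessionData | null }) => setSession(data.session))
      .catch((err) => console.error("Failed to fetch session:", err))
      .finally(() => setLoading(false));
  }, []);

  return { session, loading };
}
```
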

Welcome, {session.user.name}!

- -

Fauxcall

-

Set emergency contacts

-

If you stop speaking or say the codeword, these contacts will be notified

- {/* form for setting codeword */} -
e.preventDefault()}> - setCodeword(e.target.value)} - placeholder="Codeword" - className="border border-gray-300 rounded-md p-2" - /> - -
- {/* form for adding contacts */} -
e.preventDefault()}> + type="submit" + > + Set codeword + +
+ {/* form for adding contacts */} +
e.preventDefault()} + > - - - + + -
- + className="bg-slate-500 text-yellow-300 text-stretch-50% font-lg rounded-md p-2" + > + Save + + +

@@ -182,6 +248,4 @@ export default async function Home() {

     );
-
-
 }
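
> Editor's note: the rewritten `saveToDB` posts the email, codeword, and contact list to `/api/databaseStorage` with a promise chain. For reference, an equivalent async/await sketch with basic error handling; the payload field names match the diff, while the response body shape is not shown in this patch and is not relied on here.

```typescript
// Async/await variant of the saveToDB POST in page.tsx.
// Field names (email, codeword, contacts) come from the diff; anything else is assumed.
async function saveContacts(
  email: string,
  codeword: string,
  contacts: string[],
): Promise<boolean> {
  try {
    const res = await fetch("/api/databaseStorage", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ email, codeword, contacts }),
    });
    return res.ok;
  } catch (err) {
    console.error("Error saving contacts:", err);
    return false;
  }
}
```
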