diff --git a/Backend/api/app.py b/Backend/api/app.py
new file mode 100644
index 0000000..018061f
--- /dev/null
+++ b/Backend/api/app.py
@@ -0,0 +1,136 @@
+import os
+import logging
+import threading
+from dataclasses import dataclass
+from typing import Any, Optional
+from flask import Flask
+from flask_socketio import SocketIO
+from flask_cors import CORS
+
+# Configure logging
+logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+# Configure device
+import torch
+if torch.cuda.is_available():
+ DEVICE = "cuda"
+elif torch.backends.mps.is_available():
+ DEVICE = "mps"
+else:
+ DEVICE = "cpu"
+
+logger.info(f"Using device: {DEVICE}")
+
+# Initialize Flask app
+app = Flask(__name__, static_folder='../', static_url_path='')
+CORS(app)
+socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)
+
+# Global variables for conversation state
+active_conversations = {}
+user_queues = {}
+processing_threads = {}
+
+# Model storage
+@dataclass
+class AppModels:
+ """Container for the lazily loaded models shared across the app."""
+ generator: Any = None
+ tokenizer: Any = None
+ llm: Any = None
+ whisperx_model: Any = None
+ whisperx_align_model: Any = None
+ whisperx_align_metadata: Any = None
+ last_language: Optional[str] = None
+
+models = AppModels()
+
+def load_models():
+ """Load all required models"""
+ from api.generator import load_csm_1b  # generator.py now lives in Backend/api/
+ import whisperx
+ import gc
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ global models
+
+ socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})
+
+ # CSM 1B loading
+ try:
+ socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
+ models.generator = load_csm_1b(device=DEVICE)
+ logger.info("CSM 1B model loaded successfully")
+ socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
+ socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
+ if DEVICE == "cuda":
+ torch.cuda.empty_cache()
+ except Exception as e:
+ import traceback
+ error_details = traceback.format_exc()
+ logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
+ socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
+
+ # WhisperX loading
+ try:
+ socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
+ # Use WhisperX for better transcription with timestamps
+ # Use compute_type based on device
+ compute_type = "float16" if DEVICE == "cuda" else "float32"
+
+ # Load the WhisperX model (smaller model for faster processing)
+ models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)
+
+ logger.info("WhisperX model loaded successfully")
+ socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
+ socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
+ if DEVICE == "cuda":
+ torch.cuda.empty_cache()
+ except Exception as e:
+ import traceback
+ error_details = traceback.format_exc()
+ logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
+ socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
+
+ # Llama loading
+ try:
+ socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
+ models.llm = AutoModelForCausalLM.from_pretrained(
+ "meta-llama/Llama-3.2-1B",
+ device_map=DEVICE,
+ torch_dtype=torch.bfloat16
+ )
+ models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+
+ # Configure all special tokens
+ models.tokenizer.pad_token = models.tokenizer.eos_token
+ models.tokenizer.padding_side = "left" # For causal language modeling
+
+ # Inform the model about the pad token
+ if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
+ models.llm.config.pad_token_id = models.tokenizer.pad_token_id
+
+ logger.info("Llama 3.2 model loaded successfully")
+ socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
+ socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
+ socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
+ except Exception as e:
+ logger.error(f"Error loading Llama 3.2 model: {str(e)}")
+ socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})
+
+# Load models in a background thread
+threading.Thread(target=load_models, daemon=True).start()
+
+# Import routes and socket handlers
+from api.routes import register_routes
+from api.socket_handlers import register_handlers
+
+# Register routes and socket handlers
+register_routes(app)
+register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE)
+
+# Run server if executed directly
+if __name__ == '__main__':
+ port = int(os.environ.get('PORT', 5000))
+ debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
+ logger.info(f"Starting server on port {port} (debug={debug_mode})")
+ socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
\ No newline at end of file
diff --git a/Backend/generator.py b/Backend/api/generator.py
similarity index 100%
rename from Backend/generator.py
rename to Backend/api/generator.py
diff --git a/Backend/models.py b/Backend/api/models.py
similarity index 100%
rename from Backend/models.py
rename to Backend/api/models.py
diff --git a/Backend/api/routes.py b/Backend/api/routes.py
new file mode 100644
index 0000000..af1bfce
--- /dev/null
+++ b/Backend/api/routes.py
@@ -0,0 +1,74 @@
+import torch
+import psutil
+import transformers
+from flask import send_from_directory, jsonify
+
+def register_routes(app):
+ """Register HTTP routes for the application"""
+
+ @app.route('/')
+ def index():
+ """Serve the main application page"""
+ return send_from_directory(app.static_folder, 'index.html')
+
+ @app.route('/voice-chat.js')
+ def serve_js():
+ """Serve the JavaScript file"""
+ return send_from_directory(app.static_folder, 'voice-chat.js')
+
+ @app.route('/api/status')
+ def system_status():
+ """Return the system status"""
+ # Import here to avoid circular imports
+ from api.app import models, DEVICE
+
+ return jsonify({
+ "status": "ok",
+ "cuda_available": torch.cuda.is_available(),
+ "device": DEVICE,
+ "models": {
+ "generator": models.generator is not None,
+ "asr": models.whisperx_model is not None,
+ "llm": models.llm is not None
+ },
+ "versions": {
+ "transformers": "4.49.0", # Replace with actual version
+ "torch": torch.__version__
+ }
+ })
+
+ @app.route('/api/system_resources')
+ def system_resources():
+ """Return system resource usage"""
+ # Import here to avoid circular imports
+ from api.app import active_conversations, DEVICE
+
+ # Get CPU usage
+ cpu_percent = psutil.cpu_percent(interval=0.1)
+
+ # Get memory usage
+ memory = psutil.virtual_memory()
+ memory_used_gb = memory.used / (1024 ** 3)
+ memory_total_gb = memory.total / (1024 ** 3)
+ memory_percent = memory.percent
+
+ # Get GPU memory if available
+ gpu_memory = {}
+ if torch.cuda.is_available():
+ for i in range(torch.cuda.device_count()):
+ gpu_memory[f"gpu_{i}"] = {
+ "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
+ "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
+ "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
+ }
+
+ return jsonify({
+ "cpu_percent": cpu_percent,
+ "memory": {
+ "used_gb": memory_used_gb,
+ "total_gb": memory_total_gb,
+ "percent": memory_percent
+ },
+ "gpu_memory": gpu_memory,
+ "active_sessions": len(active_conversations)
+ })
\ No newline at end of file
diff --git a/Backend/api/socket_handlers.py b/Backend/api/socket_handlers.py
new file mode 100644
index 0000000..20513e9
--- /dev/null
+++ b/Backend/api/socket_handlers.py
@@ -0,0 +1,392 @@
+import os
+import io
+import base64
+import time
+import threading
+import queue
+import tempfile
+import gc
+import logging
+import traceback
+from typing import Dict, List, Optional
+
+import torch
+import torchaudio
+import numpy as np
+from flask import request
+from flask_socketio import emit
+
+# Import conversation model
+from api.generator import Segment
+
+logger = logging.getLogger(__name__)
+
+# Conversation data structure
+class Conversation:
+ def __init__(self, session_id):
+ self.session_id = session_id
+ self.segments: List[Segment] = []
+ self.current_speaker = 0
+ self.ai_speaker_id = 1 # Default AI speaker ID
+ self.last_activity = time.time()
+ self.is_processing = False
+
+ def add_segment(self, text, speaker, audio):
+ segment = Segment(text=text, speaker=speaker, audio=audio)
+ self.segments.append(segment)
+ self.last_activity = time.time()
+ return segment
+
+ def get_context(self, max_segments=10):
+ """Return the most recent segments for context"""
+ return self.segments[-max_segments:] if self.segments else []
+
+def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE):
+ """Register Socket.IO event handlers"""
+
+ @socketio.on('connect')
+ def handle_connect(auth=None):
+ """Handle client connection"""
+ session_id = request.sid
+ logger.info(f"Client connected: {session_id}")
+
+ # Initialize conversation data
+ if session_id not in active_conversations:
+ active_conversations[session_id] = Conversation(session_id)
+ user_queues[session_id] = queue.Queue()
+ processing_threads[session_id] = threading.Thread(
+ target=process_audio_queue,
+ args=(session_id, user_queues[session_id], app, socketio, models, active_conversations, user_queues, DEVICE),
+ daemon=True
+ )
+ processing_threads[session_id].start()
+
+ emit('connection_status', {'status': 'connected'})
+
+ @socketio.on('disconnect')
+ def handle_disconnect(reason=None):
+ """Handle client disconnection"""
+ session_id = request.sid
+ logger.info(f"Client disconnected: {session_id}. Reason: {reason}")
+
+ # Cleanup
+ if session_id in active_conversations:
+ # Mark for deletion rather than immediately removing
+ # as the processing thread might still be accessing it
+ active_conversations[session_id].is_processing = False
+ user_queues[session_id].put(None) # Signal thread to terminate
+
+ @socketio.on('audio_data')
+ def handle_audio_data(data):
+ """Handle incoming audio data"""
+ session_id = request.sid
+ logger.info(f"Received audio data from {session_id}")
+
+ # Check if the models are loaded
+ if models.generator is None or models.whisperx_model is None or models.llm is None:
+ emit('error', {'message': 'Models still loading, please wait'})
+ return
+
+ # Check if we're already processing for this session
+ if session_id in active_conversations and active_conversations[session_id].is_processing:
+ emit('error', {'message': 'Still processing previous audio, please wait'})
+ return
+
+ # Add to processing queue
+ if session_id in user_queues:
+ user_queues[session_id].put(data)
+ else:
+ emit('error', {'message': 'Session not initialized, please refresh the page'})
+
+def process_audio_queue(session_id, q, app, socketio, models, active_conversations, user_queues, DEVICE):
+ """Background thread to process audio chunks for a session"""
+ logger.info(f"Started processing thread for session: {session_id}")
+
+ try:
+ while session_id in active_conversations:
+ try:
+ # Get the next audio chunk with a timeout
+ data = q.get(timeout=120)
+ if data is None: # Termination signal
+ break
+
+ # Process the audio and generate a response
+ process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE)
+
+ except queue.Empty:
+ # Timeout, check if session is still valid
+ continue
+ except Exception as e:
+ logger.error(f"Error processing audio for {session_id}: {str(e)}")
+ # Create an app context for the socket emit
+ with app.app_context():
+ socketio.emit('error', {'message': str(e)}, room=session_id)
+ finally:
+ logger.info(f"Ending processing thread for session: {session_id}")
+ # Clean up when thread is done
+ with app.app_context():
+ if session_id in active_conversations:
+ del active_conversations[session_id]
+ if session_id in user_queues:
+ del user_queues[session_id]
+
+def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE):
+ """Process audio data and generate a response using WhisperX"""
+ if models.generator is None or models.whisperx_model is None or models.llm is None:
+ logger.warning("Models not yet loaded!")
+ with app.app_context():
+ socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
+ return
+
+ logger.info(f"Processing audio for session {session_id}")
+ conversation = active_conversations[session_id]
+
+ try:
+ # Set processing flag
+ conversation.is_processing = True
+
+ # Process base64 audio data
+ audio_data = data['audio']
+ speaker_id = data['speaker']
+ logger.info(f"Received audio from speaker {speaker_id}")
+
+ # Convert from base64 to WAV
+ try:
+ audio_bytes = base64.b64decode(audio_data.split(',')[1])
+ logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
+ except Exception as e:
+ logger.error(f"Error decoding base64 audio: {str(e)}")
+ raise
+
+ # Save to temporary file for processing
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
+ temp_file.write(audio_bytes)
+ temp_path = temp_file.name
+
+ try:
+ # Notify client that transcription is starting
+ with app.app_context():
+ socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
+
+ # Load audio using WhisperX
+ import whisperx
+ audio = whisperx.load_audio(temp_path)
+
+ # Check audio length and add a warning for short clips
+ audio_length = len(audio) / 16000 # assuming 16kHz sample rate
+ if audio_length < 1.0:
+ logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")
+
+ # Transcribe using WhisperX
+ batch_size = 16 # adjust based on your GPU memory
+ logger.info("Running WhisperX transcription...")
+
+ # Handle the warning about audio being shorter than 30s by suppressing it
+ import warnings
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", message="audio is shorter than 30s")
+ result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+
+ # Get the detected language
+ language_code = result["language"]
+ logger.info(f"Detected language: {language_code}")
+
+ # Check if alignment model needs to be loaded or updated
+ if models.whisperx_align_model is None or language_code != models.last_language:
+ # Clean up old models if they exist
+ if models.whisperx_align_model is not None:
+ del models.whisperx_align_model
+ del models.whisperx_align_metadata
+ if DEVICE == "cuda":
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Load new alignment model for the detected language
+ logger.info(f"Loading alignment model for language: {language_code}")
+ models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
+ language_code=language_code, device=DEVICE
+ )
+ models.last_language = language_code
+
+ # Align the transcript to get word-level timestamps
+ if result["segments"] and len(result["segments"]) > 0:
+ logger.info("Aligning transcript...")
+ result = whisperx.align(
+ result["segments"],
+ models.whisperx_align_model,
+ models.whisperx_align_metadata,
+ audio,
+ DEVICE,
+ return_char_alignments=False
+ )
+
+ # Process the segments for better output
+ for segment in result["segments"]:
+ # Round timestamps for better display
+ segment["start"] = round(segment["start"], 2)
+ segment["end"] = round(segment["end"], 2)
+ # Add a confidence score if not present
+ if "confidence" not in segment:
+ segment["confidence"] = 1.0 # Default confidence
+
+ # Extract the full text from all segments
+ user_text = ' '.join([segment['text'] for segment in result['segments']])
+
+ # If no text was recognized, don't process further
+ if not user_text or len(user_text.strip()) == 0:
+ with app.app_context():
+ socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
+ return
+
+ logger.info(f"Transcription: {user_text}")
+
+ # Load audio for CSM input
+ waveform, sample_rate = torchaudio.load(temp_path)
+
+ # Normalize to mono if needed
+ if waveform.shape[0] > 1:
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+ # Resample to the CSM sample rate if needed
+ if sample_rate != models.generator.sample_rate:
+ waveform = torchaudio.functional.resample(
+ waveform,
+ orig_freq=sample_rate,
+ new_freq=models.generator.sample_rate
+ )
+
+ # Add the user's message to conversation history
+ user_segment = conversation.add_segment(
+ text=user_text,
+ speaker=speaker_id,
+ audio=waveform.squeeze()
+ )
+
+ # Send transcription to client with detailed segments
+ with app.app_context():
+ socketio.emit('transcription', {
+ 'text': user_text,
+ 'speaker': speaker_id,
+ 'segments': result['segments'] # Include the detailed segments with timestamps
+ }, room=session_id)
+
+ # Generate AI response using Llama
+ with app.app_context():
+ socketio.emit('processing_status', {'status': 'generating'}, room=session_id)
+
+ # Create prompt from conversation history
+ conversation_history = ""
+ for segment in conversation.segments[-5:]: # Last 5 segments for context
+ role = "User" if segment.speaker == 0 else "Assistant"
+ conversation_history += f"{role}: {segment.text}\n"
+
+ # Add final prompt
+ prompt = f"{conversation_history}Assistant: "
+
+ # Generate response with Llama
+ try:
+ # Ensure pad token is set
+ if models.tokenizer.pad_token is None:
+ models.tokenizer.pad_token = models.tokenizer.eos_token
+
+ input_tokens = models.tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding=True,
+ return_attention_mask=True
+ )
+ input_ids = input_tokens.input_ids.to(DEVICE)
+ attention_mask = input_tokens.attention_mask.to(DEVICE)
+
+ with torch.no_grad():
+ generated_ids = models.llm.generate(
+ input_ids,
+ attention_mask=attention_mask,
+ max_new_tokens=100,
+ temperature=0.7,
+ top_p=0.9,
+ do_sample=True,
+ pad_token_id=models.tokenizer.eos_token_id
+ )
+
+ # Decode the response
+ response_text = models.tokenizer.decode(
+ generated_ids[0][input_ids.shape[1]:],
+ skip_special_tokens=True
+ ).strip()
+ except Exception as e:
+ logger.error(f"Error generating response: {str(e)}")
+ logger.error(traceback.format_exc())
+ response_text = "I'm sorry, I encountered an error while processing your request."
+
+ # Synthesize speech
+ with app.app_context():
+ socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)
+
+ # Start sending the audio response
+ socketio.emit('audio_response_start', {
+ 'text': response_text,
+ 'total_chunks': 1,
+ 'chunk_index': 0
+ }, room=session_id)
+
+ # Define AI speaker ID
+ ai_speaker_id = conversation.ai_speaker_id
+
+ # Generate audio
+ audio_tensor = models.generator.generate(
+ text=response_text,
+ speaker=ai_speaker_id,
+ context=conversation.get_context(),
+ max_audio_length_ms=10_000,
+ temperature=0.9
+ )
+
+ # Add AI response to conversation history
+ ai_segment = conversation.add_segment(
+ text=response_text,
+ speaker=ai_speaker_id,
+ audio=audio_tensor
+ )
+
+ # Convert audio to WAV format
+ with io.BytesIO() as wav_io:
+ torchaudio.save(
+ wav_io,
+ audio_tensor.unsqueeze(0).cpu(),
+ models.generator.sample_rate,
+ format="wav"
+ )
+ wav_io.seek(0)
+ wav_data = wav_io.read()
+
+ # Convert WAV data to base64
+ audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"
+
+ # Send audio chunk to client
+ with app.app_context():
+ socketio.emit('audio_response_chunk', {
+ 'chunk': audio_base64,
+ 'chunk_index': 0,
+ 'total_chunks': 1,
+ 'is_last': True
+ }, room=session_id)
+
+ # Signal completion
+ socketio.emit('audio_response_complete', {
+ 'text': response_text
+ }, room=session_id)
+
+ finally:
+ # Clean up temp file
+ if os.path.exists(temp_path):
+ os.unlink(temp_path)
+
+ except Exception as e:
+ logger.error(f"Error processing audio: {str(e)}")
+ logger.error(traceback.format_exc())
+ with app.app_context():
+ socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
+ finally:
+ # Reset processing flag
+ conversation.is_processing = False
\ No newline at end of file
diff --git a/Backend/watermarking.py b/Backend/api/watermarking.py
similarity index 100%
rename from Backend/watermarking.py
rename to Backend/api/watermarking.py
diff --git a/Backend/index.html b/Backend/index.html
deleted file mode 100644
index 9950a00..0000000
--- a/Backend/index.html
+++ /dev/null
@@ -1,419 +0,0 @@
-<!-- "CSM Voice Chat" page: header, Controls panel ("Click the button below to start and
-     stop recording."), audio visualizer ("Start speaking to see audio visualization"),
-     and Settings panel -->
\ No newline at end of file
diff --git a/Backend/requirements.txt b/Backend/requirements.txt
deleted file mode 100644
index 1e05eb3..0000000
--- a/Backend/requirements.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-flask==2.2.5
-flask-socketio==5.3.6
-flask-cors==4.0.0
-torch==2.4.0
-torchaudio==2.4.0
-tokenizers==0.21.0
-transformers==4.49.0
-librosa==0.10.1
-huggingface_hub==0.28.1
-moshi==0.2.2
-torchtune==0.4.0
-torchao==0.9.0
-silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
\ No newline at end of file
diff --git a/Backend/run_csm.py b/Backend/run_csm.py
deleted file mode 100644
index 0062973..0000000
--- a/Backend/run_csm.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import os
-import torch
-import torchaudio
-from huggingface_hub import hf_hub_download
-from generator import load_csm_1b, Segment
-from dataclasses import dataclass
-
-# Disable Triton compilation
-os.environ["NO_TORCH_COMPILE"] = "1"
-
-# Default prompts are available at https://hf.co/sesame/csm-1b
-prompt_filepath_conversational_a = hf_hub_download(
- repo_id="sesame/csm-1b",
- filename="prompts/conversational_a.wav"
-)
-prompt_filepath_conversational_b = hf_hub_download(
- repo_id="sesame/csm-1b",
- filename="prompts/conversational_b.wav"
-)
-
-SPEAKER_PROMPTS = {
- "conversational_a": {
- "text": (
- "like revising for an exam I'd have to try and like keep up the momentum because I'd "
- "start really early I'd be like okay I'm gonna start revising now and then like "
- "you're revising for ages and then I just like start losing steam I didn't do that "
- "for the exam we had recently to be fair that was a more of a last minute scenario "
- "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
- "sort of start the day with this not like a panic but like a"
- ),
- "audio": prompt_filepath_conversational_a
- },
- "conversational_b": {
- "text": (
- "like a super Mario level. Like it's very like high detail. And like, once you get "
- "into the park, it just like, everything looks like a computer game and they have all "
- "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
- "will have like a question block. And if you like, you know, punch it, a coin will "
- "come out. So like everyone, when they come into the park, they get like this little "
- "bracelet and then you can go punching question blocks around."
- ),
- "audio": prompt_filepath_conversational_b
- }
-}
-
-def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
- audio_tensor, sample_rate = torchaudio.load(audio_path)
- audio_tensor = audio_tensor.squeeze(0)
- # Resample is lazy so we can always call it
- audio_tensor = torchaudio.functional.resample(
- audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
- )
- return audio_tensor
-
-def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
- audio_tensor = load_prompt_audio(audio_path, sample_rate)
- return Segment(text=text, speaker=speaker, audio=audio_tensor)
-
-def main():
- # Select the best available device, skipping MPS due to float64 limitations
- if torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- print(f"Using device: {device}")
-
- # Load model
- generator = load_csm_1b(device)
-
- # Prepare prompts
- prompt_a = prepare_prompt(
- SPEAKER_PROMPTS["conversational_a"]["text"],
- 0,
- SPEAKER_PROMPTS["conversational_a"]["audio"],
- generator.sample_rate
- )
-
- prompt_b = prepare_prompt(
- SPEAKER_PROMPTS["conversational_b"]["text"],
- 1,
- SPEAKER_PROMPTS["conversational_b"]["audio"],
- generator.sample_rate
- )
-
- # Generate conversation
- conversation = [
- {"text": "Hey how are you doing?", "speaker_id": 0},
- {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
- {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
- {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
- ]
-
- # Generate each utterance
- generated_segments = []
- prompt_segments = [prompt_a, prompt_b]
-
- for utterance in conversation:
- print(f"Generating: {utterance['text']}")
- audio_tensor = generator.generate(
- text=utterance['text'],
- speaker=utterance['speaker_id'],
- context=prompt_segments + generated_segments,
- max_audio_length_ms=10_000,
- )
- generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))
-
- # Concatenate all generations
- all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
- torchaudio.save(
- "full_conversation.wav",
- all_audio.unsqueeze(0).cpu(),
- generator.sample_rate
- )
- print("Successfully generated full_conversation.wav")
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/Backend/server.py b/Backend/server.py
index e912a9d..b8af6b7 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -1,639 +1,19 @@
-import os
-import io
-import base64
-import time
-import json
-import uuid
-import logging
-import threading
-import queue
-import tempfile
-import gc
-from typing import Dict, List, Optional, Tuple
+"""
+CSM Voice Chat Server
+A voice chat application that uses CSM 1B for voice synthesis,
+WhisperX for speech recognition, and Llama 3.2 for language generation.
+"""
-import torch
-import torchaudio
-import numpy as np
-from flask import Flask, request, jsonify, send_from_directory
-from flask_socketio import SocketIO, emit
-from flask_cors import CORS
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+# Start the Flask application
+from api.app import app, socketio
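+# The application is assembled in api/app.py; this module simply re-exports it and,
+# when run directly, starts the Socket.IO server. Typical invocation (assuming the
+# Backend/ directory is the working directory): PORT=5000 DEBUG=false python server.py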
-# Import WhisperX for better transcription
-import whisperx
-
-from generator import load_csm_1b, Segment
-from dataclasses import dataclass
-
-# Add these imports at the top
-import psutil
-import gc
-
-# Configure logging
-logging.basicConfig(level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-# Initialize Flask app
-app = Flask(__name__, static_folder='.')
-CORS(app)
-socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)
-
-# Configure device
-if torch.cuda.is_available():
- DEVICE = "cuda"
-elif torch.backends.mps.is_available():
- DEVICE = "mps"
-else:
- DEVICE = "cpu"
-
-logger.info(f"Using device: {DEVICE}")
-
-# Global variables
-active_conversations = {}
-user_queues = {}
-processing_threads = {}
-
-# Load models
-@dataclass
-class AppModels:
- generator = None
- tokenizer = None
- llm = None
- whisperx_model = None
- whisperx_align_model = None
- whisperx_align_metadata = None
- diarize_model = None
- last_language = None
-
-# Initialize the models object
-models = AppModels()
-
-def load_models():
- """Load all required models"""
- global models
-
- socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})
-
- # CSM 1B loading
- try:
- socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
- models.generator = load_csm_1b(device=DEVICE)
- logger.info("CSM 1B model loaded successfully")
- socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
- socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
- if DEVICE == "cuda":
- torch.cuda.empty_cache()
- except Exception as e:
- import traceback
- error_details = traceback.format_exc()
- logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
- socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
-
- # WhisperX loading
- try:
- socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
- # Use WhisperX for better transcription with timestamps
- import whisperx
-
- # Use compute_type based on device
- compute_type = "float16" if DEVICE == "cuda" else "float32"
-
- # Load the WhisperX model (smaller model for faster processing)
- models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)
-
- logger.info("WhisperX model loaded successfully")
- socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
- socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
- if DEVICE == "cuda":
- torch.cuda.empty_cache()
- except Exception as e:
- import traceback
- error_details = traceback.format_exc()
- logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
- socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
-
- # Llama loading
- try:
- socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
- models.llm = AutoModelForCausalLM.from_pretrained(
- "meta-llama/Llama-3.2-1B",
- device_map=DEVICE,
- torch_dtype=torch.bfloat16
- )
- models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
-
- # Configure all special tokens
- models.tokenizer.pad_token = models.tokenizer.eos_token
- models.tokenizer.padding_side = "left" # For causal language modeling
-
- # Inform the model about the pad token
- if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
- models.llm.config.pad_token_id = models.tokenizer.pad_token_id
-
- logger.info("Llama 3.2 model loaded successfully")
- socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
- socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
- socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
- except Exception as e:
- logger.error(f"Error loading Llama 3.2 model: {str(e)}")
- socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})
-
-# Load models in a background thread
-threading.Thread(target=load_models, daemon=True).start()
-
-# Conversation data structure
-class Conversation:
- def __init__(self, session_id):
- self.session_id = session_id
- self.segments: List[Segment] = []
- self.current_speaker = 0
- self.ai_speaker_id = 1 # Add this property
- self.last_activity = time.time()
- self.is_processing = False
-
- def add_segment(self, text, speaker, audio):
- segment = Segment(text=text, speaker=speaker, audio=audio)
- self.segments.append(segment)
- self.last_activity = time.time()
- return segment
-
- def get_context(self, max_segments=10):
- """Return the most recent segments for context"""
- return self.segments[-max_segments:] if self.segments else []
-
-# Routes
-@app.route('/')
-def index():
- return send_from_directory('.', 'index.html')
-
-@app.route('/voice-chat.js')
-def voice_chat_js():
- return send_from_directory('.', 'voice-chat.js')
-
-@app.route('/api/health')
-def health_check():
- return jsonify({
- "status": "ok",
- "cuda_available": torch.cuda.is_available(),
- "models_loaded": models.generator is not None and models.llm is not None
- })
-
-# Fix the system_status function:
-
-@app.route('/api/status')
-def system_status():
- return jsonify({
- "status": "ok",
- "cuda_available": torch.cuda.is_available(),
- "device": DEVICE,
- "models": {
- "generator": models.generator is not None,
- "asr": models.whisperx_model is not None, # Use the correct model name
- "llm": models.llm is not None
- }
- })
-
-# Add a new endpoint to check system resources
-@app.route('/api/system_resources')
-def system_resources():
- # Get CPU usage
- cpu_percent = psutil.cpu_percent(interval=0.1)
-
- # Get memory usage
- memory = psutil.virtual_memory()
- memory_used_gb = memory.used / (1024 ** 3)
- memory_total_gb = memory.total / (1024 ** 3)
- memory_percent = memory.percent
-
- # Get GPU memory if available
- gpu_memory = {}
- if torch.cuda.is_available():
- for i in range(torch.cuda.device_count()):
- gpu_memory[f"gpu_{i}"] = {
- "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
- "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
- "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
- }
-
- return jsonify({
- "cpu_percent": cpu_percent,
- "memory": {
- "used_gb": memory_used_gb,
- "total_gb": memory_total_gb,
- "percent": memory_percent
- },
- "gpu_memory": gpu_memory,
- "active_sessions": len(active_conversations)
- })
-
-# Socket event handlers
-@socketio.on('connect')
-def handle_connect(auth=None):
- session_id = request.sid
- logger.info(f"Client connected: {session_id}")
-
- # Initialize conversation data
- if session_id not in active_conversations:
- active_conversations[session_id] = Conversation(session_id)
- user_queues[session_id] = queue.Queue()
- processing_threads[session_id] = threading.Thread(
- target=process_audio_queue,
- args=(session_id, user_queues[session_id]),
- daemon=True
- )
- processing_threads[session_id].start()
-
- emit('connection_status', {'status': 'connected'})
-
-@socketio.on('disconnect')
-def handle_disconnect(reason=None):
- session_id = request.sid
- logger.info(f"Client disconnected: {session_id}. Reason: {reason}")
-
- # Cleanup
- if session_id in active_conversations:
- # Mark for deletion rather than immediately removing
- # as the processing thread might still be accessing it
- active_conversations[session_id].is_processing = False
- user_queues[session_id].put(None) # Signal thread to terminate
-
-@socketio.on('start_stream')
-def handle_start_stream():
- session_id = request.sid
- logger.info(f"Starting stream for client: {session_id}")
- emit('streaming_status', {'status': 'active'})
-
-@socketio.on('stop_stream')
-def handle_stop_stream():
- session_id = request.sid
- logger.info(f"Stopping stream for client: {session_id}")
- emit('streaming_status', {'status': 'inactive'})
-
-@socketio.on('clear_context')
-def handle_clear_context():
- session_id = request.sid
- logger.info(f"Clearing context for client: {session_id}")
-
- if session_id in active_conversations:
- active_conversations[session_id].segments = []
- emit('context_updated', {'status': 'cleared'})
-
-@socketio.on('audio_chunk')
-def handle_audio_chunk(data):
- session_id = request.sid
- audio_data = data.get('audio', '')
- speaker_id = int(data.get('speaker', 0))
-
- if not audio_data or not session_id in active_conversations:
- return
-
- # Update the current speaker
- active_conversations[session_id].current_speaker = speaker_id
-
- # Queue audio for processing
- user_queues[session_id].put({
- 'audio': audio_data,
- 'speaker': speaker_id
- })
-
- emit('processing_status', {'status': 'transcribing'})
-
-def process_audio_queue(session_id, q):
- """Background thread to process audio chunks for a session"""
- logger.info(f"Started processing thread for session: {session_id}")
-
- try:
- while session_id in active_conversations:
- try:
- # Get the next audio chunk with a timeout
- data = q.get(timeout=120)
- if data is None: # Termination signal
- break
-
- # Process the audio and generate a response
- process_audio_and_respond(session_id, data)
-
- except queue.Empty:
- # Timeout, check if session is still valid
- continue
- except Exception as e:
- logger.error(f"Error processing audio for {session_id}: {str(e)}")
- # Create an app context for the socket emit
- with app.app_context():
- socketio.emit('error', {'message': str(e)}, room=session_id)
- finally:
- logger.info(f"Ending processing thread for session: {session_id}")
- # Clean up when thread is done
- with app.app_context():
- if session_id in active_conversations:
- del active_conversations[session_id]
- if session_id in user_queues:
- del user_queues[session_id]
-
-def process_audio_and_respond(session_id, data):
- """Process audio data and generate a response using WhisperX"""
- if models.generator is None or models.whisperx_model is None or models.llm is None:
- logger.warning("Models not yet loaded!")
- with app.app_context():
- socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
- return
-
- logger.info(f"Processing audio for session {session_id}")
- conversation = active_conversations[session_id]
-
- try:
- # Set processing flag
- conversation.is_processing = True
-
- # Process base64 audio data
- audio_data = data['audio']
- speaker_id = data['speaker']
- logger.info(f"Received audio from speaker {speaker_id}")
-
- # Convert from base64 to WAV
- try:
- audio_bytes = base64.b64decode(audio_data.split(',')[1])
- logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
- except Exception as e:
- logger.error(f"Error decoding base64 audio: {str(e)}")
- raise
-
- # Save to temporary file for processing
- with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
- temp_file.write(audio_bytes)
- temp_path = temp_file.name
-
- try:
- # Notify client that transcription is starting
- with app.app_context():
- socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
-
- # Load audio using WhisperX
- import whisperx
- audio = whisperx.load_audio(temp_path)
-
- # Check audio length and add a warning for short clips
- audio_length = len(audio) / 16000 # assuming 16kHz sample rate
- if audio_length < 1.0:
- logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")
-
- # Transcribe using WhisperX
- batch_size = 16 # adjust based on your GPU memory
- logger.info("Running WhisperX transcription...")
-
- # Handle the warning about audio being shorter than 30s by suppressing it
- import warnings
- with warnings.catch_warnings():
- warnings.filterwarnings("ignore", message="audio is shorter than 30s")
- result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
-
- # Get the detected language
- language_code = result["language"]
- logger.info(f"Detected language: {language_code}")
-
- # Check if alignment model needs to be loaded or updated
- if models.whisperx_align_model is None or language_code != models.last_language:
- # Clean up old models if they exist
- if models.whisperx_align_model is not None:
- del models.whisperx_align_model
- del models.whisperx_align_metadata
- if DEVICE == "cuda":
- gc.collect()
- torch.cuda.empty_cache()
-
- # Load new alignment model for the detected language
- logger.info(f"Loading alignment model for language: {language_code}")
- models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
- language_code=language_code, device=DEVICE
- )
- models.last_language = language_code
-
- # Align the transcript to get word-level timestamps
- if result["segments"] and len(result["segments"]) > 0:
- logger.info("Aligning transcript...")
- result = whisperx.align(
- result["segments"],
- models.whisperx_align_model,
- models.whisperx_align_metadata,
- audio,
- DEVICE,
- return_char_alignments=False
- )
-
- # Process the segments for better output
- for segment in result["segments"]:
- # Round timestamps for better display
- segment["start"] = round(segment["start"], 2)
- segment["end"] = round(segment["end"], 2)
- # Add a confidence score if not present
- if "confidence" not in segment:
- segment["confidence"] = 1.0 # Default confidence
-
- # Extract the full text from all segments
- user_text = ' '.join([segment['text'] for segment in result['segments']])
-
- # If no text was recognized, don't process further
- if not user_text or len(user_text.strip()) == 0:
- with app.app_context():
- socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
- return
-
- logger.info(f"Transcription: {user_text}")
-
- # Load audio for CSM input
- waveform, sample_rate = torchaudio.load(temp_path)
-
- # Normalize to mono if needed
- if waveform.shape[0] > 1:
- waveform = torch.mean(waveform, dim=0, keepdim=True)
-
- # Resample to the CSM sample rate if needed
- if sample_rate != models.generator.sample_rate:
- waveform = torchaudio.functional.resample(
- waveform,
- orig_freq=sample_rate,
- new_freq=models.generator.sample_rate
- )
-
- # Add the user's message to conversation history
- user_segment = conversation.add_segment(
- text=user_text,
- speaker=speaker_id,
- audio=waveform.squeeze()
- )
-
- # Send transcription to client with detailed segments
- with app.app_context():
- socketio.emit('transcription', {
- 'text': user_text,
- 'speaker': speaker_id,
- 'segments': result['segments'] # Include the detailed segments with timestamps
- }, room=session_id)
-
- # Generate AI response using Llama
- with app.app_context():
- socketio.emit('processing_status', {'status': 'generating'}, room=session_id)
-
- # Create prompt from conversation history
- conversation_history = ""
- for segment in conversation.segments[-5:]: # Last 5 segments for context
- role = "User" if segment.speaker == 0 else "Assistant"
- conversation_history += f"{role}: {segment.text}\n"
-
- # Add final prompt
- prompt = f"{conversation_history}Assistant: "
-
- # Generate response with Llama
- try:
- # Ensure pad token is set
- if models.tokenizer.pad_token is None:
- models.tokenizer.pad_token = models.tokenizer.eos_token
-
- input_tokens = models.tokenizer(
- prompt,
- return_tensors="pt",
- padding=True,
- return_attention_mask=True
- )
- input_ids = input_tokens.input_ids.to(DEVICE)
- attention_mask = input_tokens.attention_mask.to(DEVICE)
-
- with torch.no_grad():
- generated_ids = models.llm.generate(
- input_ids,
- attention_mask=attention_mask,
- max_new_tokens=100,
- temperature=0.7,
- top_p=0.9,
- do_sample=True,
- pad_token_id=models.tokenizer.eos_token_id
- )
-
- # Decode the response
- response_text = models.tokenizer.decode(
- generated_ids[0][input_ids.shape[1]:],
- skip_special_tokens=True
- ).strip()
- except Exception as e:
- logger.error(f"Error generating response: {str(e)}")
- import traceback
- logger.error(traceback.format_exc())
- response_text = "I'm sorry, I encountered an error while processing your request."
-
- # Synthesize speech
- with app.app_context():
- socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)
-
- # Start sending the audio response
- socketio.emit('audio_response_start', {
- 'text': response_text,
- 'total_chunks': 1,
- 'chunk_index': 0
- }, room=session_id)
-
- # Define AI speaker ID
- ai_speaker_id = conversation.ai_speaker_id
-
- # Generate audio
- audio_tensor = models.generator.generate(
- text=response_text,
- speaker=ai_speaker_id,
- context=conversation.get_context(),
- max_audio_length_ms=10_000,
- temperature=0.9
- )
-
- # Add AI response to conversation history
- ai_segment = conversation.add_segment(
- text=response_text,
- speaker=ai_speaker_id,
- audio=audio_tensor
- )
-
- # Convert audio to WAV format
- with io.BytesIO() as wav_io:
- torchaudio.save(
- wav_io,
- audio_tensor.unsqueeze(0).cpu(),
- models.generator.sample_rate,
- format="wav"
- )
- wav_io.seek(0)
- wav_data = wav_io.read()
-
- # Convert WAV data to base64
- audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"
-
- # Send audio chunk to client
- with app.app_context():
- socketio.emit('audio_response_chunk', {
- 'chunk': audio_base64,
- 'chunk_index': 0,
- 'total_chunks': 1,
- 'is_last': True
- }, room=session_id)
-
- # Signal completion
- socketio.emit('audio_response_complete', {
- 'text': response_text
- }, room=session_id)
-
- finally:
- # Clean up temp file
- if os.path.exists(temp_path):
- os.unlink(temp_path)
-
- except Exception as e:
- logger.error(f"Error processing audio: {str(e)}")
- import traceback
- logger.error(traceback.format_exc())
- with app.app_context():
- socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
- finally:
- # Reset processing flag
- conversation.is_processing = False
-
-# Error handler
-@socketio.on_error()
-def error_handler(e):
- logger.error(f"SocketIO error: {str(e)}")
-
-# Periodic cleanup of inactive sessions
-def cleanup_inactive_sessions():
- """Remove sessions that have been inactive for too long"""
- current_time = time.time()
- inactive_timeout = 3600 # 1 hour
-
- for session_id in list(active_conversations.keys()):
- conversation = active_conversations[session_id]
- if (current_time - conversation.last_activity > inactive_timeout and
- not conversation.is_processing):
-
- logger.info(f"Cleaning up inactive session: {session_id}")
-
- # Signal processing thread to terminate
- if session_id in user_queues:
- user_queues[session_id].put(None)
-
- # Remove from active conversations
- del active_conversations[session_id]
-
-# Start cleanup thread
-def start_cleanup_thread():
- while True:
- try:
- cleanup_inactive_sessions()
- except Exception as e:
- logger.error(f"Error in cleanup thread: {str(e)}")
- time.sleep(300) # Run every 5 minutes
-
-cleanup_thread = threading.Thread(target=start_cleanup_thread, daemon=True)
-cleanup_thread.start()
-
-# Start the server
if __name__ == '__main__':
+ import os
+
port = int(os.environ.get('PORT', 5000))
debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
- logger.info(f"Starting server on port {port} (debug={debug_mode})")
+
+ print(f"Starting server on port {port} (debug={debug_mode})")
+    print(f"Visit http://localhost:{port} to access the application")
+
socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
\ No newline at end of file
diff --git a/Backend/setup.py b/Backend/setup.py
deleted file mode 100644
index 8eddb95..0000000
--- a/Backend/setup.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from setuptools import setup, find_packages
-import os
-
-# Read requirements from requirements.txt
-with open('requirements.txt') as f:
- requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]
-
-setup(
- name='csm',
- version='0.1.0',
- packages=find_packages(),
- install_requires=requirements,
-)
diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js
deleted file mode 100644
index dc2db04..0000000
--- a/Backend/voice-chat.js
+++ /dev/null
@@ -1,1054 +0,0 @@
-/**
- * CSM AI Voice Chat Client
- *
- * A web client that connects to a CSM AI voice chat server and enables
- * real-time voice conversation with an AI assistant.
- */
-
-// Configuration constants
-const SERVER_URL = window.location.hostname === 'localhost' ?
- 'http://localhost:5000' : window.location.origin;
-const ENERGY_WINDOW_SIZE = 15;
-const CLIENT_SILENCE_DURATION_MS = 750;
-
-// DOM elements
-const elements = {
- conversation: document.getElementById('conversation'),
- streamButton: document.getElementById('streamButton'),
- clearButton: document.getElementById('clearButton'),
- thresholdSlider: document.getElementById('thresholdSlider'),
- thresholdValue: document.getElementById('thresholdValue'),
- visualizerCanvas: document.getElementById('audioVisualizer'),
- visualizerLabel: document.getElementById('visualizerLabel'),
- volumeLevel: document.getElementById('volumeLevel'),
- statusDot: document.getElementById('statusDot'),
- statusText: document.getElementById('statusText'),
- speakerSelection: document.getElementById('speakerSelect'),
- autoPlayResponses: document.getElementById('autoPlayResponses'),
- showVisualizer: document.getElementById('showVisualizer')
-};
-
-// Application state
-const state = {
- socket: null,
- audioContext: null,
- analyser: null,
- microphone: null,
- streamProcessor: null,
- isStreaming: false,
- isSpeaking: false,
- silenceThreshold: 0.01,
- energyWindow: [],
- silenceTimer: null,
- volumeUpdateInterval: null,
- visualizerAnimationFrame: null,
- currentSpeaker: 0,
- aiSpeakerId: 1, // Define the AI's speaker ID to match server.py
- transcriptionRetries: 0,
- maxTranscriptionRetries: 3
-};
-
-// Visualizer variables
-let canvasContext = null;
-let visualizerBufferLength = 0;
-let visualizerDataArray = null;
-
-// Audio streaming state
-const streamingAudio = {
- messageElement: null,
- audioElement: null,
- chunks: [],
- totalChunks: 0,
- receivedChunks: 0,
- text: '',
- complete: false
-};
-
-// Initialize the application
-function initializeApp() {
- // Initialize the UI elements
- initializeUIElements();
-
- // Initialize socket.io connection
- setupSocketConnection();
-
- // Setup event listeners
- setupEventListeners();
-
- // Initialize visualizer
- setupVisualizer();
-
- // Show welcome message
- addSystemMessage('Welcome to CSM Voice Chat! Click "Start Conversation" to begin.');
-}
-
-// Initialize UI elements
-function initializeUIElements() {
- // Update threshold display
- if (elements.thresholdValue) {
- elements.thresholdValue.textContent = state.silenceThreshold.toFixed(3);
- }
-}
-
-// Setup Socket.IO connection
-function setupSocketConnection() {
- state.socket = io(SERVER_URL);
-
- // Connection events
- state.socket.on('connect', () => {
- updateConnectionStatus(true);
- addSystemMessage('Connected to server.');
- });
-
- state.socket.on('disconnect', () => {
- updateConnectionStatus(false);
- addSystemMessage('Disconnected from server.');
- stopStreaming(false);
- });
-
- state.socket.on('error', (data) => {
- console.error('Server error:', data.message);
-
- // Make the error more user-friendly
- let userMessage = data.message;
-
- // Check for common errors and provide more helpful messages
- if (data.message.includes('Models still loading')) {
- userMessage = 'The AI models are still loading. Please wait a moment and try again.';
- } else if (data.message.includes('No speech detected')) {
- userMessage = 'No speech was detected. Please speak clearly and try again.';
- }
-
- addSystemMessage(`Error: ${userMessage}`);
-
- // Reset button state if it was processing
- if (elements.streamButton.classList.contains('processing')) {
- elements.streamButton.classList.remove('processing');
- elements.streamButton.innerHTML = ' Start Conversation';
- }
- });
-
- // Register message handlers
- state.socket.on('transcription', handleTranscription);
- state.socket.on('context_updated', handleContextUpdate);
- state.socket.on('streaming_status', handleStreamingStatus);
- state.socket.on('processing_status', handleProcessingStatus);
-
- // Add model status handlers
- state.socket.on('model_status', handleModelStatusUpdate);
-
- // Handlers for incremental audio streaming
- state.socket.on('audio_response_start', handleAudioResponseStart);
- state.socket.on('audio_response_chunk', handleAudioResponseChunk);
- state.socket.on('audio_response_complete', handleAudioResponseComplete);
-}
-
-// Setup event listeners
-function setupEventListeners() {
- // Stream button
- elements.streamButton.addEventListener('click', toggleStreaming);
-
- // Clear button
- elements.clearButton.addEventListener('click', clearConversation);
-
- // Threshold slider
- if (elements.thresholdSlider) {
- elements.thresholdSlider.addEventListener('input', updateThreshold);
- }
-
- // Speaker selection
- elements.speakerSelection.addEventListener('change', () => {
- state.currentSpeaker = parseInt(elements.speakerSelection.value);
- });
-
- // Visualizer toggle
- if (elements.showVisualizer) {
- elements.showVisualizer.addEventListener('change', toggleVisualizerVisibility);
- }
-}
-
-// Setup audio visualizer
-function setupVisualizer() {
- if (!elements.visualizerCanvas) return;
-
- canvasContext = elements.visualizerCanvas.getContext('2d');
-
- // Set canvas dimensions
- elements.visualizerCanvas.width = elements.visualizerCanvas.offsetWidth;
- elements.visualizerCanvas.height = elements.visualizerCanvas.offsetHeight;
-
- // Initialize visualization data array
- visualizerDataArray = new Uint8Array(128);
-
- // Start the visualizer animation
- drawVisualizer();
-}
-
-// Update connection status UI
-function updateConnectionStatus(isConnected) {
- if (isConnected) {
- elements.statusDot.classList.add('active');
- elements.statusText.textContent = 'Connected';
- } else {
- elements.statusDot.classList.remove('active');
- elements.statusText.textContent = 'Disconnected';
- }
-}
-
-// Toggle streaming state
-function toggleStreaming() {
- if (state.isStreaming) {
- stopStreaming();
- } else {
- startStreaming();
- }
-}
-
-// Start streaming audio to the server
-function startStreaming() {
- if (!state.socket || !state.socket.connected) {
- addSystemMessage('Not connected to server. Please refresh the page.');
- return;
- }
-
- // Check if models are loaded via the API
- fetch('/api/status')
- .then(response => response.json())
- .then(data => {
- if (!data.models.generator || !data.models.asr || !data.models.llm) {
- addSystemMessage('Still loading AI models. Please wait...');
- return;
- }
-
- // Continue with recording if models are loaded
- initializeRecording();
- })
- .catch(error => {
- console.error('Error checking model status:', error);
- // Try anyway, the server will respond with an error if models aren't ready
- initializeRecording();
- });
-}
-
-// Extracted the recording initialization to a separate function
-function initializeRecording() {
- // Request microphone access
- navigator.mediaDevices.getUserMedia({ audio: true, video: false })
- .then(stream => {
- state.isStreaming = true;
- elements.streamButton.classList.add('recording');
- elements.streamButton.innerHTML = ' Stop Recording';
-
- // Initialize Web Audio API
- state.audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 });
- state.microphone = state.audioContext.createMediaStreamSource(stream);
- state.analyser = state.audioContext.createAnalyser();
- state.analyser.fftSize = 2048;
-
- // Setup analyzer for visualizer
- visualizerBufferLength = state.analyser.frequencyBinCount;
- visualizerDataArray = new Uint8Array(visualizerBufferLength);
-
- state.microphone.connect(state.analyser);
-
- // Create processor node for audio data
- const processorNode = state.audioContext.createScriptProcessor(4096, 1, 1);
- processorNode.onaudioprocess = handleAudioProcess;
- state.analyser.connect(processorNode);
- processorNode.connect(state.audioContext.destination);
- state.streamProcessor = processorNode;
-
- state.silenceTimer = null;
- state.energyWindow = [];
- state.isSpeaking = false;
-
- // Notify server
- state.socket.emit('start_stream');
-
- // Start volume meter updates
- state.volumeUpdateInterval = setInterval(updateVolumeMeter, 100);
-
- // Make sure visualizer is visible if enabled
- if (elements.showVisualizer && elements.showVisualizer.checked) {
- elements.visualizerLabel.style.opacity = '0';
- }
-
- addSystemMessage('Recording started. Speak now...');
- })
- .catch(error => {
- console.error('Error accessing microphone:', error);
- addSystemMessage('Could not access microphone. Please check permissions.');
- });
-}
-
-// Stop streaming audio
-function stopStreaming(notifyServer = true) {
- if (state.isStreaming) {
- state.isStreaming = false;
- elements.streamButton.classList.remove('recording');
- elements.streamButton.classList.remove('processing');
- elements.streamButton.innerHTML = ' Start Conversation';
-
- // Clean up audio resources
- if (state.streamProcessor) {
- state.streamProcessor.disconnect();
- state.streamProcessor = null;
- }
-
- if (state.analyser) {
- state.analyser.disconnect();
- state.analyser = null;
- }
-
- if (state.microphone) {
- state.microphone.disconnect();
- state.microphone = null;
- }
-
- if (state.audioContext) {
- state.audioContext.close().catch(err => console.warn('Error closing audio context:', err));
- state.audioContext = null;
- }
-
- // Clear any pending silence timer
- if (state.silenceTimer) {
- clearTimeout(state.silenceTimer);
- state.silenceTimer = null;
- }
-
- // Clear volume meter updates
- if (state.volumeUpdateInterval) {
- clearInterval(state.volumeUpdateInterval);
- state.volumeUpdateInterval = null;
-
- // Reset volume meter
- if (elements.volumeLevel) {
- elements.volumeLevel.style.width = '0%';
- }
- }
-
- // Show visualizer label
- if (elements.visualizerLabel) {
- elements.visualizerLabel.style.opacity = '0.7';
- }
-
- // Notify server if needed
- if (notifyServer && state.socket && state.socket.connected) {
- state.socket.emit('stop_stream');
- }
-
- addSystemMessage('Recording stopped.');
- }
-}
-
-// Handle audio processing
-function handleAudioProcess(event) {
- if (!state.isStreaming) return;
-
- const inputData = event.inputBuffer.getChannelData(0);
- const energy = calculateAudioEnergy(inputData);
- updateEnergyWindow(energy);
-
- const averageEnergy = calculateAverageEnergy();
- const isSilent = averageEnergy < state.silenceThreshold;
-
- handleSpeechState(isSilent);
-}
-
-// Calculate audio energy (volume)
-function calculateAudioEnergy(buffer) {
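- // Root-mean-square (RMS) of the samples: a simple loudness estimate in the 0..1 range.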
- let sum = 0;
- for (let i = 0; i < buffer.length; i++) {
- sum += buffer[i] * buffer[i];
- }
- return Math.sqrt(sum / buffer.length);
-}
-
-// Update energy window for averaging
-function updateEnergyWindow(energy) {
- state.energyWindow.push(energy);
- if (state.energyWindow.length > ENERGY_WINDOW_SIZE) {
- state.energyWindow.shift();
- }
-}
-
-// Calculate average energy from window
-function calculateAverageEnergy() {
- if (state.energyWindow.length === 0) return 0;
-
- const sum = state.energyWindow.reduce((acc, val) => acc + val, 0);
- return sum / state.energyWindow.length;
-}
-
-// Update the threshold from the slider
-function updateThreshold() {
- state.silenceThreshold = parseFloat(elements.thresholdSlider.value);
- elements.thresholdValue.textContent = state.silenceThreshold.toFixed(3);
-}
-
-// Update the volume meter display
-function updateVolumeMeter() {
- if (!state.isStreaming || !state.energyWindow.length || !elements.volumeLevel) return;
-
- const avgEnergy = calculateAverageEnergy();
-
- // Scale energy to percentage (0-100)
- // Energy values are typically very small (e.g., 0.001 to 0.1)
- const scaleFactor = 1000;
- const percentage = Math.min(100, Math.max(0, avgEnergy * scaleFactor));
-
- // Update volume meter width
- elements.volumeLevel.style.width = `${percentage}%`;
-
- // Change color based on level
- if (percentage > 70) {
- elements.volumeLevel.style.backgroundColor = '#ff5252';
- } else if (percentage > 30) {
- elements.volumeLevel.style.backgroundColor = '#4CAF50';
- } else {
- elements.volumeLevel.style.backgroundColor = '#4c84ff';
- }
-}
-
-// Handle speech/silence state transitions
-function handleSpeechState(isSilent) {
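- // Two-state machine: while speaking, sustained silence for CLIENT_SILENCE_DURATION_MS
- // ends the utterance and sends the captured audio; any new sound cancels the pending timer.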
- if (state.isSpeaking) {
- if (isSilent) {
- // User was speaking but now is silent
- if (!state.silenceTimer) {
- state.silenceTimer = setTimeout(() => {
- // Silence lasted long enough, consider speech done
- if (state.isSpeaking) {
- state.isSpeaking = false;
-
- try {
- // Get the current audio data and send it
- const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max
- state.analyser.getFloatTimeDomainData(audioBuffer);
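- // getFloatTimeDomainData fills only analyser.fftSize (2048) samples; the rest of
- // this 5-second buffer stays at zero, so only a short snapshot of the most recent
- // audio is actually captured here.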
-
- // Check if audio has content
- const hasAudioContent = audioBuffer.some(sample => Math.abs(sample) > 0.01);
-
- if (!hasAudioContent) {
- console.warn('Audio buffer appears to be empty or very quiet');
-
- if (state.transcriptionRetries < state.maxTranscriptionRetries) {
- state.transcriptionRetries++;
- const retryMessage = `No speech detected (attempt ${state.transcriptionRetries}/${state.maxTranscriptionRetries}). Please speak louder and try again.`;
- addSystemMessage(retryMessage);
- } else {
- state.transcriptionRetries = 0;
- addSystemMessage('Multiple attempts failed to detect speech. Please check your microphone and try again.');
- }
- return;
- }
-
- // Create WAV blob
- const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate);
-
- // Convert to base64
- const reader = new FileReader();
- reader.onloadend = function() {
- sendAudioChunk(reader.result, state.currentSpeaker);
- };
- reader.readAsDataURL(wavBlob);
-
- // Update button state
- elements.streamButton.classList.add('processing');
- elements.streamButton.innerHTML = ' Processing...';
-
- addSystemMessage('Processing your message...');
- } catch (e) {
- console.error('Error recording audio:', e);
- addSystemMessage('Error recording audio. Please try again.');
- }
- }
- }, CLIENT_SILENCE_DURATION_MS);
- }
- } else {
- // User is still speaking, reset silence timer
- if (state.silenceTimer) {
- clearTimeout(state.silenceTimer);
- state.silenceTimer = null;
- }
- }
- } else {
- if (!isSilent) {
- // User started speaking
- state.isSpeaking = true;
- if (state.silenceTimer) {
- clearTimeout(state.silenceTimer);
- state.silenceTimer = null;
- }
- }
- }
-}
-
-// Send audio chunk to server
-function sendAudioChunk(audioData, speaker) {
- if (state.socket && state.socket.connected) {
- state.socket.emit('audio_chunk', {
- audio: audioData,
- speaker: speaker
- });
- }
-}
-
-// Create WAV blob from audio data
-function createWavBlob(audioData, sampleRate) {
- const numChannels = 1;
- const bitsPerSample = 16;
- const bytesPerSample = bitsPerSample / 8;
-
- // Create buffer for WAV file
- const buffer = new ArrayBuffer(44 + audioData.length * bytesPerSample);
- const view = new DataView(buffer);
-
- // Write WAV header
- // "RIFF" chunk descriptor
- writeString(view, 0, 'RIFF');
- view.setUint32(4, 36 + audioData.length * bytesPerSample, true);
- writeString(view, 8, 'WAVE');
-
- // "fmt " sub-chunk
- writeString(view, 12, 'fmt ');
- view.setUint32(16, 16, true); // subchunk1size
- view.setUint16(20, 1, true); // audio format (PCM)
- view.setUint16(22, numChannels, true);
- view.setUint32(24, sampleRate, true);
- view.setUint32(28, sampleRate * numChannels * bytesPerSample, true); // byte rate
- view.setUint16(32, numChannels * bytesPerSample, true); // block align
- view.setUint16(34, bitsPerSample, true);
-
- // "data" sub-chunk
- writeString(view, 36, 'data');
- view.setUint32(40, audioData.length * bytesPerSample, true);
-
- // Write audio data
- const audioDataStart = 44;
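- // Convert each float sample in [-1, 1] to little-endian signed 16-bit PCM; negatives
- // scale by 0x8000 and positives by 0x7FFF to span the full int16 range.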
- for (let i = 0; i < audioData.length; i++) {
- const sample = Math.max(-1, Math.min(1, audioData[i]));
- const value = sample < 0 ? sample * 0x8000 : sample * 0x7FFF;
- view.setInt16(audioDataStart + i * bytesPerSample, value, true);
- }
-
- return new Blob([buffer], { type: 'audio/wav' });
-}
-
-// Helper function to write strings to DataView
-function writeString(view, offset, string) {
- for (let i = 0; i < string.length; i++) {
- view.setUint8(offset + i, string.charCodeAt(i));
- }
-}
-
-// Clear conversation history
-function clearConversation() {
- elements.conversation.innerHTML = '';
- if (state.socket && state.socket.connected) {
- state.socket.emit('clear_context');
- }
- addSystemMessage('Conversation cleared.');
-}
-
-// Draw audio visualizer
-function drawVisualizer() {
- if (!canvasContext || !elements.visualizerCanvas) {
- state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer);
- return;
- }
-
- state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer);
-
- // Skip drawing if visualizer is hidden or not enabled
- if (elements.showVisualizer && !elements.showVisualizer.checked) {
- if (elements.visualizerCanvas.style.opacity !== '0') {
- elements.visualizerCanvas.style.opacity = '0';
- }
- return;
- } else if (elements.visualizerCanvas.style.opacity !== '1') {
- elements.visualizerCanvas.style.opacity = '1';
- }
-
- // Get frequency data if available
- if (state.isStreaming && state.analyser) {
- try {
- state.analyser.getByteFrequencyData(visualizerDataArray);
- } catch (e) {
- console.warn('Error getting frequency data:', e);
- }
- } else {
- // Fade out when not streaming
- for (let i = 0; i < visualizerDataArray.length; i++) {
- visualizerDataArray[i] = Math.max(0, visualizerDataArray[i] - 5);
- }
- }
-
- // Clear canvas
- canvasContext.fillStyle = 'rgb(0, 0, 0)';
- canvasContext.fillRect(0, 0, elements.visualizerCanvas.width, elements.visualizerCanvas.height);
-
- // Draw gradient bars
- const width = elements.visualizerCanvas.width;
- const height = elements.visualizerCanvas.height;
- const barCount = Math.min(visualizerBufferLength, 64);
- const barWidth = width / barCount - 1;
-
- for (let i = 0; i < barCount; i++) {
- const index = Math.floor(i * visualizerBufferLength / barCount);
- const value = visualizerDataArray[index];
-
- // Use logarithmic scale for better audio visualization
- const logFactor = 20;
- const scaledValue = Math.log(1 + (value / 255) * logFactor) / Math.log(1 + logFactor);
- const barHeight = scaledValue * height;
-
- // Position bars
- const x = i * (barWidth + 1);
- const y = height - barHeight;
-
- // Create color gradient based on frequency and amplitude
- const hue = i / barCount * 360; // Full color spectrum
- const saturation = 80 + (value / 255 * 20); // Higher values more saturated
- const lightness = 40 + (value / 255 * 20); // Dynamic brightness
-
- // Draw main bar
- canvasContext.fillStyle = `hsl(${hue}, ${saturation}%, ${lightness}%)`;
- canvasContext.fillRect(x, y, barWidth, barHeight);
-
- // Add highlight effect
- if (barHeight > 5) {
- const gradient = canvasContext.createLinearGradient(
- x, y,
- x, y + barHeight * 0.5
- );
- gradient.addColorStop(0, `hsla(${hue}, ${saturation}%, ${lightness + 20}%, 0.4)`);
- gradient.addColorStop(1, `hsla(${hue}, ${saturation}%, ${lightness}%, 0)`);
- canvasContext.fillStyle = gradient;
- canvasContext.fillRect(x, y, barWidth, barHeight * 0.5);
-
- // Add highlight on top of the bar
- canvasContext.fillStyle = `hsla(${hue}, ${saturation - 20}%, ${lightness + 30}%, 0.7)`;
- canvasContext.fillRect(x, y, barWidth, 2);
- }
- }
-}
-
-// Toggle visualizer visibility
-function toggleVisualizerVisibility() {
- const isVisible = elements.showVisualizer.checked;
- elements.visualizerCanvas.style.opacity = isVisible ? '1' : '0';
-}
-
-// Add a message to the conversation
-function addMessage(text, type) {
- if (!elements.conversation) return;
-
- const messageDiv = document.createElement('div');
- messageDiv.className = `message ${type}`;
-
- const textElement = document.createElement('p');
- textElement.textContent = text;
- messageDiv.appendChild(textElement);
-
- // Add timestamp to every message
- const timestamp = new Date().toLocaleTimeString();
- const timeLabel = document.createElement('div');
- timeLabel.className = 'message-timestamp';
- timeLabel.textContent = timestamp;
- messageDiv.appendChild(timeLabel);
-
- elements.conversation.appendChild(messageDiv);
-
- // Auto-scroll to the bottom
- elements.conversation.scrollTop = elements.conversation.scrollHeight;
-
- return messageDiv;
-}
-
-// Add a system message to the conversation
-function addSystemMessage(text) {
- if (!elements.conversation) return;
-
- const messageDiv = document.createElement('div');
- messageDiv.className = 'message system';
- messageDiv.textContent = text;
-
- elements.conversation.appendChild(messageDiv);
-
- // Auto-scroll to the bottom
- elements.conversation.scrollTop = elements.conversation.scrollHeight;
-
- return messageDiv;
-}
-
-// Handle transcription response from server
-function handleTranscription(data) {
- const speaker = data.speaker === 0 ? 'user' : 'ai';
-
- // Create the message div
- const messageDiv = addMessage(data.text, speaker);
-
- // If we have detailed segments from WhisperX, add timestamps
- if (data.segments && data.segments.length > 0) {
- // Add a timestamps container
- const timestampsContainer = document.createElement('div');
- timestampsContainer.className = 'timestamps-container';
- timestampsContainer.style.display = 'none'; // Hidden by default
-
- // Add a toggle button
- const toggleButton = document.createElement('button');
- toggleButton.className = 'timestamp-toggle';
- toggleButton.textContent = 'Show Timestamps';
- toggleButton.onclick = function() {
- const isHidden = timestampsContainer.style.display === 'none';
- timestampsContainer.style.display = isHidden ? 'block' : 'none';
- toggleButton.textContent = isHidden ? 'Hide Timestamps' : 'Show Timestamps';
- };
-
- // Add timestamps for each segment
- data.segments.forEach(segment => {
- const timestampDiv = document.createElement('div');
- timestampDiv.className = 'timestamp';
-
- // Format start and end times
- const startTime = formatTime(segment.start);
- const endTime = formatTime(segment.end);
-
- timestampDiv.innerHTML = `
- <span class="time">[${startTime} - ${endTime}]</span>
- <span class="text">${segment.text}</span>
- `;
-
- timestampsContainer.appendChild(timestampDiv);
- });
-
- // Add the timestamp elements to the message
- messageDiv.appendChild(toggleButton);
- messageDiv.appendChild(timestampsContainer);
- } else {
- // No timestamp data available - add a simple timestamp for the entire message
- const timestamp = new Date().toLocaleTimeString();
- const timeLabel = document.createElement('div');
- timeLabel.className = 'simple-timestamp';
- timeLabel.textContent = timestamp;
- messageDiv.appendChild(timeLabel);
- }
-
- return messageDiv;
-}
-
-// Helper function to format time in seconds to MM:SS.ms format
-function formatTime(seconds) {
- const mins = Math.floor(seconds / 60);
- const secs = Math.floor(seconds % 60);
- const ms = Math.floor((seconds % 1) * 1000);
-
- return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}.${ms.toString().padStart(3, '0')}`;
-}
-
-// Handle context update from server
-function handleContextUpdate(data) {
- if (data.status === 'cleared') {
- elements.conversation.innerHTML = '';
- addSystemMessage('Conversation context cleared.');
- }
-}
-
-// Handle streaming status updates from server
-function handleStreamingStatus(data) {
- if (data.status === 'active') {
- console.log('Server acknowledged streaming is active');
- } else if (data.status === 'inactive') {
- console.log('Server acknowledged streaming is inactive');
- }
-}
-
-// Handle processing status updates
-function handleProcessingStatus(data) {
- switch (data.status) {
- case 'transcribing':
- addSystemMessage('Transcribing your message...');
- break;
- case 'generating':
- addSystemMessage('Generating response...');
- break;
- case 'synthesizing':
- addSystemMessage('Synthesizing voice...');
- break;
- }
-}
-
-// Handle the start of an audio streaming response
-function handleAudioResponseStart(data) {
- console.log(`Expecting ${data.total_chunks} audio chunks`);
-
- // Reset streaming state
- streamingAudio.chunks = [];
- streamingAudio.totalChunks = data.total_chunks;
- streamingAudio.receivedChunks = 0;
- streamingAudio.text = data.text;
- streamingAudio.complete = false;
-}
-
-// Handle an incoming audio chunk
-function handleAudioResponseChunk(data) {
- // Create or update audio element for playback
- const audioElement = document.createElement('audio');
- if (elements.autoPlayResponses.checked) {
- audioElement.autoplay = true;
- }
- audioElement.controls = true;
- audioElement.className = 'audio-player';
- audioElement.src = data.chunk;
-
- // Store the chunk
- streamingAudio.chunks[data.chunk_index] = data.chunk;
- streamingAudio.receivedChunks++;
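- // Chunks are keyed by chunk_index, so the array can be sparse if chunks arrive out of order.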
-
- // Store audio element reference for later use
- streamingAudio.audioElement = audioElement;
-
- // Add to the conversation
- const messages = elements.conversation.querySelectorAll('.message.ai');
- if (messages.length > 0) {
- const lastAiMessage = messages[messages.length - 1];
- streamingAudio.messageElement = lastAiMessage;
-
- // Replace existing audio player if there is one
- const existingPlayer = lastAiMessage.querySelector('.audio-player');
- if (existingPlayer) {
- lastAiMessage.replaceChild(audioElement, existingPlayer);
- } else {
- lastAiMessage.appendChild(audioElement);
- }
- } else {
- // Create a new message for the AI response
- const aiMessage = document.createElement('div');
- aiMessage.className = 'message ai';
- streamingAudio.messageElement = aiMessage;
-
- if (streamingAudio.text) {
- const textElement = document.createElement('p');
- textElement.textContent = streamingAudio.text;
- aiMessage.appendChild(textElement);
- }
-
- aiMessage.appendChild(audioElement);
- elements.conversation.appendChild(aiMessage);
- }
-
- // Auto-scroll
- elements.conversation.scrollTop = elements.conversation.scrollHeight;
-
- // If this is the last chunk or we've received all expected chunks
- if (data.is_last || streamingAudio.receivedChunks >= streamingAudio.totalChunks) {
- streamingAudio.complete = true;
-
- // Reset stream button if we're still streaming
- if (state.isStreaming) {
- elements.streamButton.classList.remove('processing');
- elements.streamButton.innerHTML = ' Listening...';
- }
- }
-}
-
-// Handle completion of audio streaming
-function handleAudioResponseComplete(data) {
- console.log('Audio response complete:', data);
- streamingAudio.complete = true;
-
- // Make sure we finalize the audio even if some chunks were missed
- finalizeStreamingAudio();
-
- // Update UI to normal state
- if (state.isStreaming) {
- elements.streamButton.innerHTML = ' Listening...';
- elements.streamButton.classList.add('recording');
- elements.streamButton.classList.remove('processing');
- }
-}
-
-// Finalize streaming audio by combining chunks and updating the UI
-function finalizeStreamingAudio() {
- if (!streamingAudio.messageElement || streamingAudio.chunks.length === 0) {
- return;
- }
-
- try {
- // More sophisticated streaming would concatenate the WAV chunks properly; for now we
- // use the last chunk as the complete audio, since the server's implementation should
- // include the entire response in it
- const lastChunkIndex = streamingAudio.chunks.length - 1;
- const audioData = streamingAudio.chunks[lastChunkIndex] || streamingAudio.chunks[0];
-
- // Update the audio element with the complete audio
- if (streamingAudio.audioElement) {
- streamingAudio.audioElement.src = audioData;
-
- // Auto-play if enabled and not already playing
- if (elements.autoPlayResponses && elements.autoPlayResponses.checked &&
- streamingAudio.audioElement.paused) {
- streamingAudio.audioElement.play()
- .catch(err => {
- console.warn('Auto-play failed:', err);
- addSystemMessage('Auto-play failed. Please click play to hear the response.');
- });
- }
- }
-
- // Remove loading indicator and processing class
- if (streamingAudio.messageElement) {
- const loadingElement = streamingAudio.messageElement.querySelector('.loading-indicator');
- if (loadingElement) {
- streamingAudio.messageElement.removeChild(loadingElement);
- }
- streamingAudio.messageElement.classList.remove('processing');
- }
-
- console.log('Audio response finalized and ready for playback');
- } catch (e) {
- console.error('Error finalizing streaming audio:', e);
- }
-
- // Reset streaming audio state
- streamingAudio.chunks = [];
- streamingAudio.totalChunks = 0;
- streamingAudio.receivedChunks = 0;
- streamingAudio.messageElement = null;
- streamingAudio.audioElement = null;
-}
-
-// Handle model loading status updates from the server
-function handleModelStatusUpdate(data) {
- const { model, status, message, progress } = data;
-
- if (model === 'overall' && status === 'loading') {
- // Update overall loading progress
- const progressBar = document.getElementById('modelLoadingProgress');
- if (progressBar) {
- progressBar.value = progress;
- progressBar.textContent = `${progress}%`;
- }
- return;
- }
-
- if (status === 'loaded') {
- console.log(`Model ${model} loaded successfully`);
- addSystemMessage(`${model.toUpperCase()} model loaded successfully`);
-
- // Update UI to show model is ready
- const modelStatusElement = document.getElementById(`${model}Status`);
- if (modelStatusElement) {
- modelStatusElement.classList.remove('loading');
- modelStatusElement.classList.add('loaded');
- modelStatusElement.title = 'Model loaded successfully';
- }
-
- // Check if the required models are loaded to enable conversation
- checkAllModelsLoaded();
- } else if (status === 'error') {
- console.error(`Error loading ${model} model: ${message}`);
- addSystemMessage(`Error loading ${model.toUpperCase()} model: ${message}`);
-
- // Update UI to show model loading failed
- const modelStatusElement = document.getElementById(`${model}Status`);
- if (modelStatusElement) {
- modelStatusElement.classList.remove('loading');
- modelStatusElement.classList.add('error');
- modelStatusElement.title = `Error: ${message}`;
- }
- }
-}
-
-// Check if all required models are loaded and enable UI accordingly
-function checkAllModelsLoaded() {
- // When all models are loaded, enable the stream button if it was disabled
- const allLoaded =
- document.getElementById('csmStatus')?.classList.contains('loaded') &&
- document.getElementById('asrStatus')?.classList.contains('loaded') &&
- document.getElementById('llmStatus')?.classList.contains('loaded');
-
- if (allLoaded) {
- elements.streamButton.disabled = false;
- addSystemMessage('All models loaded. Ready for conversation!');
- }
-}
-
-// Add CSS styles for new UI elements
-document.addEventListener('DOMContentLoaded', function() {
- // Add styles for processing state and timestamps
- const style = document.createElement('style');
- style.textContent = `
- .message.processing {
- opacity: 0.8;
- }
-
- .loading-indicator {
- display: flex;
- align-items: center;
- margin-top: 8px;
- font-size: 0.9em;
- color: #666;
- }
-
- .loading-spinner {
- width: 16px;
- height: 16px;
- border: 2px solid #ddd;
- border-top: 2px solid var(--primary-color);
- border-radius: 50%;
- margin-right: 8px;
- animation: spin 1s linear infinite;
- }
-
- @keyframes spin {
- 0% { transform: rotate(0deg); }
- 100% { transform: rotate(360deg); }
- }
-
- /* Timestamp styles */
- .timestamp-toggle {
- font-size: 0.75em;
- padding: 4px 8px;
- margin-top: 8px;
- background-color: #f0f0f0;
- border: 1px solid #ddd;
- border-radius: 4px;
- cursor: pointer;
- }
-
- .timestamp-toggle:hover {
- background-color: #e0e0e0;
- }
-
- .timestamps-container {
- margin-top: 8px;
- padding: 8px;
- background-color: #f9f9f9;
- border-radius: 4px;
- font-size: 0.85em;
- }
-
- .timestamp {
- margin-bottom: 4px;
- padding: 2px 0;
- }
-
- .timestamp .time {
- color: #666;
- font-family: monospace;
- margin-right: 8px;
- }
-
- .timestamp .text {
- color: #333;
- }
- `;
- document.head.appendChild(style);
-});
-
-// Initialize the application when DOM is fully loaded
-document.addEventListener('DOMContentLoaded', initializeApp);
-
diff --git a/React/src/app/auth/session/route.ts b/React/src/app/auth/session/route.ts
new file mode 100644
index 0000000..9299d4a
--- /dev/null
+++ b/React/src/app/auth/session/route.ts
@@ -0,0 +1,12 @@
+import { NextResponse } from "next/server";
+import { auth0 } from "../../../lib/auth0";
+
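+// Returns the current Auth0 session (or null) so client components can check auth state
+// without calling auth0.getSession() on the server.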
+export async function GET() {
+ try {
+ const session = await auth0.getSession();
+ return NextResponse.json({ session });
+ } catch (error) {
+ console.error("Error getting session:", error);
+ return NextResponse.json({ session: null }, { status: 500 });
+ }
+}
diff --git a/React/src/app/page.tsx b/React/src/app/page.tsx
index c91ed2b..21e0862 100644
--- a/React/src/app/page.tsx
+++ b/React/src/app/page.tsx
@@ -1,67 +1,94 @@
-import { useState } from "react";
-import { auth0 } from "../lib/auth0";
+"use client";
+import { useState, useEffect } from "react";
+import { useRouter } from "next/navigation";
-
-export default async function Home() {
-
+export default function Home() {
const [contacts, setContacts] = useState([]);
const [codeword, setCodeword] = useState("");
+ const [session, setSession] = useState(null);
+ const [loading, setLoading] = useState(true);
+ const router = useRouter();
- const session = await auth0.getSession();
-
- console.log("Session:", session?.user);
+ useEffect(() => {
+ // Fetch session data from an API route
+ fetch("/auth/session")
+ .then((response) => response.json())
+ .then((data) => {
+ setSession(data.session);
+ setLoading(false);
+ })
+ .catch((error) => {
+ console.error("Failed to fetch session:", error);
+ setLoading(false);
+ });
+ }, []);
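+ // The session is fetched client-side because this page is now a client component
+ // ("use client") and can no longer await auth0.getSession() during render.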
function saveToDB() {
- //e.preventDefault();
alert("Saving contacts...");
- // const contactInputs = document.querySelectorAll(".text-input") as NodeListOf<HTMLInputElement>;
- // const contactValues = Array.from(contactInputs).map(input => input.value);
- // console.log("Contact values:", contactValues);
- // // save codeword and contacts to database
- // fetch("/api/databaseStorage", {
- // method: "POST",
- // headers: {
- // "Content-Type": "application/json",
- // },
- // body: JSON.stringify({
- // email: session?.user?.email || "",
- // codeword: (document.getElementById("codeword") as HTMLInputElement)?.value,
- // contacts: contactValues,
- // }),
- // })
- // .then((response) => {
- // if (response.ok) {
- // alert("Contacts saved successfully!");
- // } else {
- // alert("Error saving contacts.");
- // }
- // })
- // .catch((error) => {
- // console.error("Error:", error);
- // alert("Error saving contacts.");
- // });
-
+ const contactInputs = document.querySelectorAll(
+ ".text-input"
+ ) as NodeListOf<HTMLInputElement>;
+ const contactValues = Array.from(contactInputs).map((input) => input.value);
+
+ fetch("/api/databaseStorage", {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ },
+ body: JSON.stringify({
+ email: session?.user?.email || "",
+ codeword: codeword,
+ contacts: contactValues,
+ }),
+ })
+ .then((response) => {
+ if (response.ok) {
+ alert("Contacts saved successfully!");
+ } else {
+ alert("Error saving contacts.");
+ }
+ })
+ .catch((error) => {
+ console.error("Error:", error);
+ alert("Error saving contacts.");
+ });
}
+ if (loading) {
+ return <div>Loading...</div>;
+ }
// If no session, show sign-up and login buttons
- if (!session) {
-
+ if (!session) {
return (
- <div>
- <h1>Fauxcall</h1>
- <h2>Set emergency contacts</h2>
- <p>If you stop speaking or say the codeword, these contacts will be notified</p>
+ <div>
+ <h1>Fauxcall</h1>
+ <h2>Set emergency contacts</h2>
+ <p>
+ If you stop speaking or say the codeword, these contacts will be
+ notified
+ </p>
{/* form for setting codeword */}
-
{/* form for adding contacts */}
-
);
@@ -107,25 +145,42 @@ export default async function Home() {
);
-
-
}