From 10902f1d713b13295d762bb76b55709f531c34ba Mon Sep 17 00:00:00 2001
From: GamerBoss101
Date: Sun, 30 Mar 2025 02:40:08 -0400
Subject: [PATCH] Demo Update 18

---
 Backend/server.py     | 122 +++++++++++++++++++++++++++---------------
 Backend/voice-chat.js |  94 +++++++++++++++++++++++++++++++-
 2 files changed, 171 insertions(+), 45 deletions(-)

diff --git a/Backend/server.py b/Backend/server.py
index 992e674..4cc4f91 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -8,6 +8,7 @@ import logging
 import threading
 import queue
 import tempfile
+import gc
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -18,6 +19,9 @@ from flask_socketio import SocketIO, emit
 from flask_cors import CORS
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
+# Import WhisperX for better transcription
+import whisperx
+
 from generator import load_csm_1b, Segment
 from dataclasses import dataclass
 
@@ -52,7 +56,10 @@ class AppModels:
     generator = None
     tokenizer = None
     llm = None
-    asr = None
+    whisperx_model = None
+    whisperx_align_model = None
+    whisperx_align_metadata = None
+    diarize_model = None
 
 models = AppModels()
 
@@ -69,22 +76,16 @@ def load_models():
         logger.error(f"Error loading CSM 1B model: {str(e)}")
         socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
 
-    logger.info("Loading ASR pipeline...")
+    logger.info("Loading WhisperX model...")
     try:
-        # Initialize the pipeline without the language parameter in the constructor
-        models.asr = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-small",
-            device=DEVICE
-        )
-
-        # Configure the model with the appropriate options
-        # Note that for whisper, language should be set during inference, not initialization
-        logger.info("ASR pipeline loaded successfully")
-        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
+        # Use WhisperX instead of the regular Whisper
+        compute_type = "float16" if DEVICE == "cuda" else "float32"
+        models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
+        logger.info("WhisperX model loaded successfully")
+        socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'})
     except Exception as e:
-        logger.error(f"Error loading ASR pipeline: {str(e)}")
-        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
+        logger.error(f"Error loading WhisperX model: {str(e)}")
+        socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)})
 
     logger.info("Loading Llama 3.2 model...")
     try:
@@ -149,7 +150,7 @@ def system_status():
         "device": DEVICE,
         "models": {
             "generator": models.generator is not None,
-            "asr": models.asr is not None,
+            "whisperx": models.whisperx_model is not None,
             "llm": models.llm is not None
         }
     })
@@ -259,8 +260,8 @@ def process_audio_queue(session_id, q):
             del user_queues[session_id]
 
 def process_audio_and_respond(session_id, data):
-    """Process audio data and generate a response"""
-    if models.generator is None or models.asr is None or models.llm is None:
+    """Process audio data and generate a response using WhisperX"""
+    if models.generator is None or models.whisperx_model is None or models.llm is None:
         logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -292,7 +293,57 @@ def process_audio_and_respond(session_id, data):
         temp_path = temp_file.name
 
     try:
-        # Load audio file
+        # Load audio using WhisperX
+        with app.app_context():
+            socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
+
+        # Load audio with WhisperX instead of torchaudio
+        audio = whisperx.load_audio(temp_path)
+
+        # Transcribe using WhisperX
+        batch_size = 16  # Adjust based on available memory
+        result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+
+        # Get the detected language
+        language_code = result["language"]
+        logger.info(f"Detected language: {language_code}")
+
+        # Load alignment model if not already loaded
+        if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
+            # Clear previous models to save memory
+            if models.whisperx_align_model is not None:
+                del models.whisperx_align_model
+                del models.whisperx_align_metadata
+                gc.collect()
+                torch.cuda.empty_cache() if DEVICE == "cuda" else None
+
+            models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
+                language_code=language_code, device=DEVICE
+            )
+            models.last_language = language_code
+
+        # Align the transcript
+        result = whisperx.align(
+            result["segments"],
+            models.whisperx_align_model,
+            models.whisperx_align_metadata,
+            audio,
+            DEVICE,
+            return_char_alignments=False
+        )
+
+        # Combine all segments into a single transcript
+        user_text = ' '.join([segment['text'] for segment in result['segments']])
+
+        # If no text was recognized, don't process further
+        if not user_text or len(user_text.strip()) == 0:
+            with app.app_context():
+                socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
+            return
+
+        logger.info(f"Transcription: {user_text}")
+
+        # Load audio for CSM input
         waveform, sample_rate = torchaudio.load(temp_path)
 
         # Normalize to mono if needed
@@ -307,24 +358,6 @@ def process_audio_and_respond(session_id, data):
                 new_freq=models.generator.sample_rate
             )
 
-        # Transcribe audio
-        with app.app_context():
-            socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
-
-        # Use the ASR pipeline to transcribe
-        transcription_result = models.asr(
-            {"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
-            return_timestamps=False,
-            generate_kwargs={"language": "en"}  # Set language during inference
-        )
-        user_text = transcription_result['text'].strip()
-
-        # If no text was recognized, don't process further
-        if not user_text:
-            with app.app_context():
-                socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
-            return
-
         # Add the user's message to conversation history
         user_segment = conversation.add_segment(
             text=user_text,
@@ -336,7 +369,8 @@ def process_audio_and_respond(session_id, data):
         with app.app_context():
             socketio.emit('transcription', {
                 'text': user_text,
-                'speaker': speaker_id
+                'speaker': speaker_id,
+                'segments': result['segments']  # Send detailed segments info
             }, room=session_id)
 
         # Generate AI response using Llama
@@ -365,7 +399,7 @@ def process_audio_and_respond(session_id, data):
         with torch.no_grad():
             generated_ids = models.llm.generate(
                 input_ids,
-                attention_mask=attention_mask,  # Add the attention mask
+                attention_mask=attention_mask,
                 max_new_tokens=100,
                 temperature=0.7,
                 top_p=0.9,
@@ -390,13 +424,13 @@ def process_audio_and_respond(session_id, data):
                 'chunk_index': 0
             }, room=session_id)
 
-        # Define AI speaker ID (use a consistent value for the AI's voice)
-        ai_speaker_id = 1  # Use speaker 1 for AI responses
+        # Define AI speaker ID
+        ai_speaker_id = conversation.ai_speaker_id
 
         # Generate audio
         audio_tensor = models.generator.generate(
             text=response_text,
-            speaker=ai_speaker_id,  # Use the local variable instead of conversation.ai_speaker_id
+            speaker=ai_speaker_id,
             context=conversation.get_context(),
             max_audio_length_ms=10_000,
             temperature=0.9
@@ -405,7 +439,7 @@ def process_audio_and_respond(session_id, data):
         # Add AI response to conversation history
         ai_segment = conversation.add_segment(
             text=response_text,
-            speaker=ai_speaker_id,  # Also use the local variable here
+            speaker=ai_speaker_id,
             audio=audio_tensor
         )
 
@@ -444,6 +478,8 @@ def process_audio_and_respond(session_id, data):
 
     except Exception as e:
         logger.error(f"Error processing audio: {str(e)}")
+        import traceback
+        logger.error(traceback.format_exc())
         with app.app_context():
             socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
     finally:
diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js
index 5c3f247..705d5ab 100644
--- a/Backend/voice-chat.js
+++ b/Backend/voice-chat.js
@@ -627,7 +627,59 @@ function addSystemMessage(text) {
 // Handle transcription response from server
 function handleTranscription(data) {
     const speaker = data.speaker === 0 ? 'user' : 'ai';
-    addMessage(data.text, speaker);
+
+    // Create the message div
+    const messageDiv = addMessage(data.text, speaker);
+
+    // If we have detailed segments from WhisperX, add timestamps
+    if (data.segments && data.segments.length > 0) {
+        // Add a timestamps container
+        const timestampsContainer = document.createElement('div');
+        timestampsContainer.className = 'timestamps-container';
+        timestampsContainer.style.display = 'none'; // Hidden by default
+
+        // Add a toggle button
+        const toggleButton = document.createElement('button');
+        toggleButton.className = 'timestamp-toggle';
+        toggleButton.textContent = 'Show Timestamps';
+        toggleButton.onclick = function() {
+            const isHidden = timestampsContainer.style.display === 'none';
+            timestampsContainer.style.display = isHidden ? 'block' : 'none';
+            toggleButton.textContent = isHidden ? 'Hide Timestamps' : 'Show Timestamps';
+        };
+
+        // Add timestamps for each segment
+        data.segments.forEach(segment => {
+            const timestampDiv = document.createElement('div');
+            timestampDiv.className = 'timestamp';
+
+            // Format start and end times
+            const startTime = formatTime(segment.start);
+            const endTime = formatTime(segment.end);
+
+            timestampDiv.innerHTML = `
+                <span class="time">[${startTime} - ${endTime}]</span>
+                <span class="text">${segment.text}</span>
+            `;
+
+            timestampsContainer.appendChild(timestampDiv);
+        });
+
+        // Add the timestamp elements to the message
+        messageDiv.appendChild(toggleButton);
+        messageDiv.appendChild(timestampsContainer);
+    }
+
+    return messageDiv;
+}
+
+// Helper function to format time in seconds to MM:SS.ms format
+function formatTime(seconds) {
+    const mins = Math.floor(seconds / 60);
+    const secs = Math.floor(seconds % 60);
+    const ms = Math.floor((seconds % 1) * 1000);
+
+    return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}.${ms.toString().padStart(3, '0')}`;
 }
 
 // Handle context update from server
@@ -804,7 +856,7 @@ function finalizeStreamingAudio() {
 
 // Add CSS styles for new UI elements
 document.addEventListener('DOMContentLoaded', function() {
-    // Add styles for processing state
+    // Add styles for processing state and timestamps
     const style = document.createElement('style');
     style.textContent = `
         .message.processing {
@@ -833,6 +885,44 @@ document.addEventListener('DOMContentLoaded', function() {
             0% { transform: rotate(0deg); }
             100% { transform: rotate(360deg); }
         }
+
+        /* Timestamp styles */
+        .timestamp-toggle {
+            font-size: 0.75em;
+            padding: 4px 8px;
+            margin-top: 8px;
+            background-color: #f0f0f0;
+            border: 1px solid #ddd;
+            border-radius: 4px;
+            cursor: pointer;
+        }
+
+        .timestamp-toggle:hover {
+            background-color: #e0e0e0;
+        }
+
+        .timestamps-container {
+            margin-top: 8px;
+            padding: 8px;
+            background-color: #f9f9f9;
+            border-radius: 4px;
+            font-size: 0.85em;
+        }
+
+        .timestamp {
+            margin-bottom: 4px;
+            padding: 2px 0;
+        }
+
+        .timestamp .time {
+            color: #666;
+            font-family: monospace;
+            margin-right: 8px;
+        }
+
+        .timestamp .text {
+            color: #333;
+        }
     `;
     document.head.appendChild(style);
 });
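
For reference, the transcribe-then-align flow that this patch moves process_audio_and_respond() onto can be exercised on its own, outside the Flask/SocketIO server. The following is a minimal standalone sketch, assuming whisperx and torch are installed; the input file name "audio.wav" is hypothetical, and the "large-v2" model and batch size of 16 simply mirror the values chosen in the patch rather than anything required by the library.

    # Standalone sketch of the WhisperX flow used in server.py:
    # load audio, batch-transcribe, then word-align the segments.
    import torch
    import whisperx

    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "float16" if device == "cuda" else "float32"

    model = whisperx.load_model("large-v2", device, compute_type=compute_type)
    audio = whisperx.load_audio("audio.wav")  # hypothetical input file

    result = model.transcribe(audio, batch_size=16)
    align_model, metadata = whisperx.load_align_model(
        language_code=result["language"], device=device
    )
    aligned = whisperx.align(
        result["segments"], align_model, metadata, audio, device,
        return_char_alignments=False
    )

    # Each aligned segment carries 'start'/'end' times, which is what the
    # front end renders via formatTime() in voice-chat.js.
    print(" ".join(seg["text"] for seg in aligned["segments"]))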