From c8551f90b361c8abe3e73392670a9a8259268b71 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 03:19:23 -0400 Subject: [PATCH] Demo Fixes 8 --- Backend/server.py | 131 ++++++++++++++++++++++++++++------------------ 1 file changed, 81 insertions(+), 50 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index 8145ab0..e912a9d 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -60,8 +60,11 @@ class AppModels: generator = None tokenizer = None llm = None - asr_model = None - asr_processor = None + whisperx_model = None + whisperx_align_model = None + whisperx_align_metadata = None + diarize_model = None + last_language = None # Initialize the models object models = AppModels() @@ -87,25 +90,27 @@ def load_models(): logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) - # Whisper loading + # WhisperX loading try: socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'}) - # Use regular Whisper instead of WhisperX to avoid compatibility issues - from transformers import WhisperProcessor, WhisperForConditionalGeneration + # Use WhisperX for better transcription with timestamps + import whisperx - # Use a smaller model for faster processing - model_id = "openai/whisper-small" + # Use compute_type based on device + compute_type = "float16" if DEVICE == "cuda" else "float32" - models.asr_processor = WhisperProcessor.from_pretrained(model_id) - models.asr_model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE) + # Load the WhisperX model (smaller model for faster processing) + models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type) - logger.info("Whisper ASR model loaded successfully") + logger.info("WhisperX model loaded successfully") socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66}) if DEVICE == "cuda": torch.cuda.empty_cache() except Exception as e: - logger.error(f"Error loading ASR model: {str(e)}") + import traceback + error_details = traceback.format_exc() + logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}") socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) # Llama loading @@ -184,7 +189,7 @@ def system_status(): "device": DEVICE, "models": { "generator": models.generator is not None, - "asr": models.asr_model is not None, # Use the correct model name + "asr": models.whisperx_model is not None, # Use the correct model name "llm": models.llm is not None } }) @@ -327,8 +332,8 @@ def process_audio_queue(session_id, q): del user_queues[session_id] def process_audio_and_respond(session_id, data): - """Process audio data and generate a response using standard Whisper""" - if models.generator is None or models.asr_model is None or models.llm is None: + """Process audio data and generate a response using WhisperX""" + if models.generator is None or models.whisperx_model is None or models.llm is None: logger.warning("Models not yet loaded!") with app.app_context(): socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) @@ -364,44 +369,69 @@ def process_audio_and_respond(session_id, data): with app.app_context(): socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) - # Load audio for ASR processing - import librosa - speech_array, sampling_rate = librosa.load(temp_path, sr=16000) + # Load audio using WhisperX + import whisperx + audio = whisperx.load_audio(temp_path) - # Convert to required format - processor_output = models.asr_processor( - speech_array, - sampling_rate=sampling_rate, - return_tensors="pt", - padding=True, # Add padding - return_attention_mask=True # Request attention mask - ) - input_features = processor_output.input_features.to(DEVICE) - attention_mask = processor_output.get('attention_mask', None) - - if attention_mask is not None: - attention_mask = attention_mask.to(DEVICE) + # Check audio length and add a warning for short clips + audio_length = len(audio) / 16000 # assuming 16kHz sample rate + if audio_length < 1.0: + logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality") + + # Transcribe using WhisperX + batch_size = 16 # adjust based on your GPU memory + logger.info("Running WhisperX transcription...") + + # Handle the warning about audio being shorter than 30s by suppressing it + import warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="audio is shorter than 30s") + result = models.whisperx_model.transcribe(audio, batch_size=batch_size) + + # Get the detected language + language_code = result["language"] + logger.info(f"Detected language: {language_code}") + + # Check if alignment model needs to be loaded or updated + if models.whisperx_align_model is None or language_code != models.last_language: + # Clean up old models if they exist + if models.whisperx_align_model is not None: + del models.whisperx_align_model + del models.whisperx_align_metadata + if DEVICE == "cuda": + gc.collect() + torch.cuda.empty_cache() - # Generate token ids with attention mask - predicted_ids = models.asr_model.generate( - input_features, - attention_mask=attention_mask, - language="en", - task="transcribe" - ) - else: - # Fallback if attention mask is not available - predicted_ids = models.asr_model.generate( - input_features, - language="en", - task="transcribe" + # Load new alignment model for the detected language + logger.info(f"Loading alignment model for language: {language_code}") + models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model( + language_code=language_code, device=DEVICE ) + models.last_language = language_code - # Decode the predicted ids to text - user_text = models.asr_processor.batch_decode( - predicted_ids, - skip_special_tokens=True - )[0] + # Align the transcript to get word-level timestamps + if result["segments"] and len(result["segments"]) > 0: + logger.info("Aligning transcript...") + result = whisperx.align( + result["segments"], + models.whisperx_align_model, + models.whisperx_align_metadata, + audio, + DEVICE, + return_char_alignments=False + ) + + # Process the segments for better output + for segment in result["segments"]: + # Round timestamps for better display + segment["start"] = round(segment["start"], 2) + segment["end"] = round(segment["end"], 2) + # Add a confidence score if not present + if "confidence" not in segment: + segment["confidence"] = 1.0 # Default confidence + + # Extract the full text from all segments + user_text = ' '.join([segment['text'] for segment in result['segments']]) # If no text was recognized, don't process further if not user_text or len(user_text.strip()) == 0: @@ -433,11 +463,12 @@ def process_audio_and_respond(session_id, data): audio=waveform.squeeze() ) - # Send transcription to client + # Send transcription to client with detailed segments with app.app_context(): socketio.emit('transcription', { 'text': user_text, - 'speaker': speaker_id + 'speaker': speaker_id, + 'segments': result['segments'] # Include the detailed segments with timestamps }, room=session_id) # Generate AI response using Llama