Demo Fixes 8

commit c8551f90b3
parent 284dd50972
Date:   2025-03-30 03:19:23 -04:00


@@ -60,8 +60,11 @@ class AppModels:
     generator = None
     tokenizer = None
     llm = None
-    asr_model = None
-    asr_processor = None
+    whisperx_model = None
+    whisperx_align_model = None
+    whisperx_align_metadata = None
+    diarize_model = None
+    last_language = None

 # Initialize the models object
 models = AppModels()
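For orientation, after this hunk the shared container looks roughly like the sketch below: a plain module-level singleton whose attributes start as `None` and are filled in by `load_models()`, so request handlers only ever test `is not None`. The `asr_ready` helper is hypothetical, added here purely to illustrate the pattern.

```python
# Sketch of the model container after this change. Attribute names mirror
# the diff; asr_ready() is a hypothetical convenience, not in the codebase.
class AppModels:
    generator = None                # CSM 1B speech generator
    tokenizer = None
    llm = None                      # Llama response model
    whisperx_model = None           # WhisperX ASR model
    whisperx_align_model = None     # per-language alignment model
    whisperx_align_metadata = None  # metadata paired with the align model
    diarize_model = None            # reserved for speaker diarization
    last_language = None            # language the cached align model serves

models = AppModels()

def asr_ready() -> bool:
    # Transcription can run once the WhisperX model is loaded.
    return models.whisperx_model is not None
```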
@@ -87,25 +90,27 @@ def load_models():
         logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
         socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})

-    # Whisper loading
+    # WhisperX loading
     try:
         socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})

-        # Use regular Whisper instead of WhisperX to avoid compatibility issues
-        from transformers import WhisperProcessor, WhisperForConditionalGeneration
+        # Use WhisperX for better transcription with timestamps
+        import whisperx

-        # Use a smaller model for faster processing
-        model_id = "openai/whisper-small"
+        # Use compute_type based on device
+        compute_type = "float16" if DEVICE == "cuda" else "float32"

-        models.asr_processor = WhisperProcessor.from_pretrained(model_id)
-        models.asr_model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
+        # Load the WhisperX model (smaller model for faster processing)
+        models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)

-        logger.info("Whisper ASR model loaded successfully")
+        logger.info("WhisperX model loaded successfully")
         socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
         socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
         if DEVICE == "cuda":
             torch.cuda.empty_cache()
     except Exception as e:
-        logger.error(f"Error loading ASR model: {str(e)}")
+        import traceback
+        error_details = traceback.format_exc()
+        logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
         socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})

     # Llama loading
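As a standalone sketch, the loading pattern this hunk introduces looks like the following, assuming `whisperx` and `torch` are installed. The detail worth keeping is that `compute_type` must match the device: the CTranslate2 backend behind WhisperX rejects float16 on CPU, so float32 (or int8) is the safe fallback there.

```python
# Minimal sketch of device-aware WhisperX loading; model size "small"
# trades some accuracy for speed and VRAM.
import torch
import whisperx

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if DEVICE == "cuda" else "float32"  # fp16 needs CUDA

model = whisperx.load_model("small", DEVICE, compute_type=compute_type)
```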
@@ -184,7 +189,7 @@ def system_status():
         "device": DEVICE,
         "models": {
             "generator": models.generator is not None,
-            "asr": models.asr_model is not None,  # Use the correct model name
+            "asr": models.whisperx_model is not None,  # Use the correct model name
             "llm": models.llm is not None
         }
     })
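Because the WhisperX model is still reported under the existing `asr` key, status consumers need no changes. A hypothetical poll of the endpoint (the route path and port are assumptions; the response keys come from the handler above):

```python
# Hypothetical client-side readiness poll; adjust the URL to the real route.
import requests

status = requests.get("http://localhost:5000/system_status").json()
print(status["device"])          # e.g. "cuda"
print(status["models"]["asr"])   # True once whisperx_model is loaded
```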
@@ -327,8 +332,8 @@ def process_audio_queue(session_id, q):
         del user_queues[session_id]

 def process_audio_and_respond(session_id, data):
-    """Process audio data and generate a response using standard Whisper"""
-    if models.generator is None or models.asr_model is None or models.llm is None:
+    """Process audio data and generate a response using WhisperX"""
+    if models.generator is None or models.whisperx_model is None or models.llm is None:
         logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -364,44 +369,69 @@ def process_audio_and_respond(session_id, data):
         with app.app_context():
             socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)

-        # Load audio for ASR processing
-        import librosa
-        speech_array, sampling_rate = librosa.load(temp_path, sr=16000)
+        # Load audio using WhisperX
+        import whisperx
+        audio = whisperx.load_audio(temp_path)

-        # Convert to required format
-        processor_output = models.asr_processor(
-            speech_array,
-            sampling_rate=sampling_rate,
-            return_tensors="pt",
-            padding=True,  # Add padding
-            return_attention_mask=True  # Request attention mask
-        )
-        input_features = processor_output.input_features.to(DEVICE)
-        attention_mask = processor_output.get('attention_mask', None)
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(DEVICE)
-            # Generate token ids with attention mask
-            predicted_ids = models.asr_model.generate(
-                input_features,
-                attention_mask=attention_mask,
-                language="en",
-                task="transcribe"
-            )
-        else:
-            # Fallback if attention mask is not available
-            predicted_ids = models.asr_model.generate(
-                input_features,
-                language="en",
-                task="transcribe"
-            )
-        # Decode the predicted ids to text
-        user_text = models.asr_processor.batch_decode(
-            predicted_ids,
-            skip_special_tokens=True
-        )[0]
+        # Check audio length and add a warning for short clips
+        audio_length = len(audio) / 16000  # assuming 16kHz sample rate
+        if audio_length < 1.0:
+            logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")
+
+        # Transcribe using WhisperX
+        batch_size = 16  # adjust based on your GPU memory
+        logger.info("Running WhisperX transcription...")
+        # Handle the warning about audio being shorter than 30s by suppressing it
+        import warnings
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="audio is shorter than 30s")
+            result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+
+        # Get the detected language
+        language_code = result["language"]
+        logger.info(f"Detected language: {language_code}")
+
+        # Check if alignment model needs to be loaded or updated
+        if models.whisperx_align_model is None or language_code != models.last_language:
+            # Clean up old models if they exist
+            if models.whisperx_align_model is not None:
+                del models.whisperx_align_model
+                del models.whisperx_align_metadata
+                if DEVICE == "cuda":
+                    gc.collect()
+                    torch.cuda.empty_cache()
+            # Load new alignment model for the detected language
+            logger.info(f"Loading alignment model for language: {language_code}")
+            models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
+                language_code=language_code, device=DEVICE
+            )
+            models.last_language = language_code
+
+        # Align the transcript to get word-level timestamps
+        if result["segments"] and len(result["segments"]) > 0:
+            logger.info("Aligning transcript...")
+            result = whisperx.align(
+                result["segments"],
+                models.whisperx_align_model,
+                models.whisperx_align_metadata,
+                audio,
+                DEVICE,
+                return_char_alignments=False
+            )
+
+        # Process the segments for better output
+        for segment in result["segments"]:
+            # Round timestamps for better display
+            segment["start"] = round(segment["start"], 2)
+            segment["end"] = round(segment["end"], 2)
+            # Add a confidence score if not present
+            if "confidence" not in segment:
+                segment["confidence"] = 1.0  # Default confidence
+
+        # Extract the full text from all segments
+        user_text = ' '.join([segment['text'] for segment in result['segments']])

         # If no text was recognized, don't process further
         if not user_text or len(user_text.strip()) == 0:
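Condensed, the new transcription path is: transcribe, cache one alignment model per detected language, align for word-level timestamps, then flatten the segments back to plain text. A simplified reconstruction under the same assumptions as the diff (`whisperx` installed, 16 kHz input), with the Socket.IO and error-handling plumbing omitted:

```python
# Simplified sketch of the transcribe -> align pipeline above; the mutable
# default dict acts as a cross-call cache for the per-language align model.
import gc
import torch
import whisperx

def transcribe_with_alignment(model, path, device, align_cache={}):
    audio = whisperx.load_audio(path)            # 16 kHz mono float32
    result = model.transcribe(audio, batch_size=16)
    lang = result["language"]

    # Reload the alignment model only when the detected language changes.
    if align_cache.get("lang") != lang:
        align_cache.clear()
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
        am, meta = whisperx.load_align_model(language_code=lang, device=device)
        align_cache.update(lang=lang, model=am, metadata=meta)

    if result["segments"]:
        result = whisperx.align(result["segments"], align_cache["model"],
                                align_cache["metadata"], audio, device,
                                return_char_alignments=False)

    text = " ".join(seg["text"].strip() for seg in result["segments"])
    return text, result["segments"]
```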
@@ -433,11 +463,12 @@ def process_audio_and_respond(session_id, data):
             audio=waveform.squeeze()
         )

-        # Send transcription to client
+        # Send transcription to client with detailed segments
         with app.app_context():
             socketio.emit('transcription', {
                 'text': user_text,
-                'speaker': speaker_id
+                'speaker': speaker_id,
+                'segments': result['segments']  # Include the detailed segments with timestamps
             }, room=session_id)

         # Generate AI response using Llama
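With the segments included, the `transcription` event now gives the client enough to render timed captions. A hypothetical payload and consumption loop (field names follow the diff; the values are illustrative only):

```python
# Illustrative payload shape for the enriched 'transcription' event.
payload = {
    "text": "hello there how are you",
    "speaker": 0,
    "segments": [
        {"start": 0.0,  "end": 1.24, "text": " hello there", "confidence": 1.0},
        {"start": 1.24, "end": 2.81, "text": " how are you", "confidence": 1.0},
    ],
}

# Render simple timed captions from the segment list.
for seg in payload["segments"]:
    print(f"[{seg['start']:6.2f} -> {seg['end']:6.2f}] {seg['text'].strip()}")
```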