Demo Update 18

2025-03-30 02:40:08 -04:00
parent faacb612e7
commit 10902f1d71
2 changed files with 171 additions and 45 deletions

@@ -8,6 +8,7 @@ import logging
import threading
import queue
import tempfile
import gc
from typing import Dict, List, Optional, Tuple
import torch
@@ -18,6 +19,9 @@ from flask_socketio import SocketIO, emit
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Import WhisperX for better transcription
import whisperx
from generator import load_csm_1b, Segment
from dataclasses import dataclass
@@ -52,7 +56,10 @@ class AppModels:
generator = None
tokenizer = None
llm = None
asr = None
whisperx_model = None
whisperx_align_model = None
whisperx_align_metadata = None
diarize_model = None
models = AppModels()
@@ -69,22 +76,16 @@ def load_models():
logger.error(f"Error loading CSM 1B model: {str(e)}")
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
logger.info("Loading ASR pipeline...")
logger.info("Loading WhisperX model...")
try:
# Initialize the pipeline without the language parameter in the constructor
models.asr = pipeline(
"automatic-speech-recognition",
model="openai/whisper-small",
device=DEVICE
)
# Configure the model with the appropriate options
# Note that for whisper, language should be set during inference, not initialization
logger.info("ASR pipeline loaded successfully")
socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
# Use WhisperX instead of the regular Whisper
compute_type = "float16" if DEVICE == "cuda" else "float32"
models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
logger.info("WhisperX model loaded successfully")
socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'})
except Exception as e:
logger.error(f"Error loading ASR pipeline: {str(e)}")
socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
logger.error(f"Error loading WhisperX model: {str(e)}")
socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)})
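For context, the compute-type selection above can be sketched in isolation roughly as follows. This is a minimal sketch assuming the whisperx package imported in this file; the helper name and the int8 fallback are illustrative assumptions, not part of this commit.

import torch
import whisperx

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_whisperx_model(name="large-v2"):
    # float16 roughly halves GPU memory use; CPU inference with the CTranslate2 backend
    # generally wants float32 (or int8) instead.
    compute_type = "float16" if DEVICE == "cuda" else "float32"
    try:
        return whisperx.load_model(name, DEVICE, compute_type=compute_type)
    except Exception:
        # Assumption: retry with int8 if the requested compute type is unsupported on this hardware.
        return whisperx.load_model(name, DEVICE, compute_type="int8")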
logger.info("Loading Llama 3.2 model...")
try:
@@ -149,7 +150,7 @@ def system_status():
"device": DEVICE,
"models": {
"generator": models.generator is not None,
"asr": models.asr is not None,
"whisperx": models.whisperx_model is not None,
"llm": models.llm is not None
}
})
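For reference, a fully loaded server would now report a system-status payload shaped roughly like the dict below (a sketch; "device" may be "cpu", and each flag reflects whether that model has finished loading).

# Illustrative shape of the status JSON after this change:
expected_status = {
    "device": "cuda",
    "models": {
        "generator": True,   # CSM 1B generator
        "whisperx": True,    # renamed from the old "asr" flag in this commit
        "llm": True,         # Llama 3.2
    },
}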
@@ -259,8 +260,8 @@ def process_audio_queue(session_id, q):
del user_queues[session_id]
def process_audio_and_respond(session_id, data):
"""Process audio data and generate a response"""
if models.generator is None or models.asr is None or models.llm is None:
"""Process audio data and generate a response using WhisperX"""
if models.generator is None or models.whisperx_model is None or models.llm is None:
logger.warning("Models not yet loaded!")
with app.app_context():
socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -292,7 +293,57 @@ def process_audio_and_respond(session_id, data):
temp_path = temp_file.name
try:
# Load audio file
# Load audio using WhisperX
with app.app_context():
socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
# Load audio with WhisperX instead of torchaudio
audio = whisperx.load_audio(temp_path)
# Transcribe using WhisperX
batch_size = 16 # Adjust based on available memory
result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
# Get the detected language
language_code = result["language"]
logger.info(f"Detected language: {language_code}")
# Load alignment model if not already loaded
if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
# Clear previous models to save memory
if models.whisperx_align_model is not None:
del models.whisperx_align_model
del models.whisperx_align_metadata
gc.collect()
if DEVICE == "cuda": torch.cuda.empty_cache()
models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
language_code=language_code, device=DEVICE
)
models.last_language = language_code
# Align the transcript
result = whisperx.align(
result["segments"],
models.whisperx_align_model,
models.whisperx_align_metadata,
audio,
DEVICE,
return_char_alignments=False
)
# Combine all segments into a single transcript
user_text = ' '.join([segment['text'] for segment in result['segments']])
# If no text was recognized, don't process further
if not user_text or len(user_text.strip()) == 0:
with app.app_context():
socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
return
logger.info(f"Transcription: {user_text}")
# Load audio for CSM input
waveform, sample_rate = torchaudio.load(temp_path)
# Normalize to mono if needed
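Taken together, the new transcription path boils down to the flow below: load the audio, transcribe in batches, load (or reuse) a per-language alignment model, align, and join the segment texts. A minimal sketch assuming the whisperx calls used in this diff; the helper name and the module-level cache dict are illustrative stand-ins for the fields kept on AppModels.

import gc
import torch
import whisperx

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_align_cache = {}  # language_code -> (align_model, metadata)

def transcribe_with_whisperx(model, audio_path, batch_size=16):
    """Transcribe a file and word-align the result, reusing one alignment model per language."""
    audio = whisperx.load_audio(audio_path)
    result = model.transcribe(audio, batch_size=batch_size)
    language_code = result["language"]

    if language_code not in _align_cache:
        # Keep at most one alignment model in memory, as the server does.
        _align_cache.clear()
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        _align_cache[language_code] = whisperx.load_align_model(
            language_code=language_code, device=DEVICE
        )

    align_model, align_metadata = _align_cache[language_code]
    aligned = whisperx.align(
        result["segments"], align_model, align_metadata, audio, DEVICE,
        return_char_alignments=False,
    )
    text = " ".join(segment["text"].strip() for segment in aligned["segments"])
    return text, aligned["segments"]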
@@ -307,24 +358,6 @@ def process_audio_and_respond(session_id, data):
new_freq=models.generator.sample_rate
)
# Transcribe audio
with app.app_context():
socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
# Use the ASR pipeline to transcribe
transcription_result = models.asr(
{"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
return_timestamps=False,
generate_kwargs={"language": "en"} # Set language during inference
)
user_text = transcription_result['text'].strip()
# If no text was recognized, don't process further
if not user_text:
with app.app_context():
socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
return
# Add the user's message to conversation history
user_segment = conversation.add_segment(
text=user_text,
@@ -336,7 +369,8 @@ def process_audio_and_respond(session_id, data):
with app.app_context():
socketio.emit('transcription', {
'text': user_text,
'speaker': speaker_id
'speaker': speaker_id,
'segments': result['segments'] # Send detailed segments info
}, room=session_id)
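The 'segments' list forwarded in that event is WhisperX's aligned output; each entry looks roughly like the sketch below. The exact keys can vary between WhisperX versions, and the values shown are invented.

# Illustrative shape of one aligned segment:
example_segment = {
    "start": 0.52,
    "end": 2.31,
    "text": "Hello there.",
    "words": [
        {"word": "Hello", "start": 0.52, "end": 0.98, "score": 0.91},
        {"word": "there.", "start": 1.05, "end": 1.40, "score": 0.88},
    ],
}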
# Generate AI response using Llama
@@ -365,7 +399,7 @@ def process_audio_and_respond(session_id, data):
with torch.no_grad():
generated_ids = models.llm.generate(
input_ids,
attention_mask=attention_mask, # Add the attention mask
attention_mask=attention_mask,
max_new_tokens=100,
temperature=0.7,
top_p=0.9,
@@ -390,13 +424,13 @@ def process_audio_and_respond(session_id, data):
'chunk_index': 0
}, room=session_id)
# Define AI speaker ID (use a consistent value for the AI's voice)
ai_speaker_id = 1 # Use speaker 1 for AI responses
# Define AI speaker ID
ai_speaker_id = conversation.ai_speaker_id
# Generate audio
audio_tensor = models.generator.generate(
text=response_text,
speaker=ai_speaker_id, # Use the local variable instead of conversation.ai_speaker_id
speaker=ai_speaker_id,
context=conversation.get_context(),
max_audio_length_ms=10_000,
temperature=0.9
@@ -405,7 +439,7 @@ def process_audio_and_respond(session_id, data):
# Add AI response to conversation history
ai_segment = conversation.add_segment(
text=response_text,
speaker=ai_speaker_id, # Also use the local variable here
speaker=ai_speaker_id,
audio=audio_tensor
)
@@ -444,6 +478,8 @@ def process_audio_and_respond(session_id, data):
except Exception as e:
logger.error(f"Error processing audio: {str(e)}")
import traceback
logger.error(traceback.format_exc())
with app.app_context():
socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
finally: