Demo Update 20

Swap the WhisperX transcription pipeline for plain Hugging Face Whisper (openai/whisper-small) to avoid compatibility issues; rename the model fields and socket status events from 'whisperx' to 'asr' and drop the per-segment payload from the transcription event.
@@ -56,12 +56,8 @@ class AppModels:
     generator = None
     tokenizer = None
     llm = None
-    whisperx_model = None
-    whisperx_align_model = None
-    whisperx_align_metadata = None
-    diarize_model = None
-
-models = AppModels()
+    asr_model = None
+    asr_processor = None
 
 def load_models():
     """Load all required models"""
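Note: the models live on one shared holder object so route handlers and background threads read the same state. A minimal sketch of that pattern, assuming the module-level instance is (re)created elsewhere, since this hunk removes the original instantiation; the require helper is hypothetical, not part of the commit:

# Minimal sketch of the shared-holder pattern; 'require' is a hypothetical helper.
class AppModels:
    generator = None
    tokenizer = None
    llm = None
    asr_model = None
    asr_processor = None

models = AppModels()

def require(name):
    """Return a loaded model attribute, or fail loudly if loading hasn't finished."""
    value = getattr(models, name)
    if value is None:
        raise RuntimeError(f"Model '{name}' is not loaded yet")
    return value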
@@ -76,16 +72,22 @@ def load_models():
         logger.error(f"Error loading CSM 1B model: {str(e)}")
         socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
 
-    logger.info("Loading WhisperX model...")
+    logger.info("Loading Whisper ASR model...")
     try:
-        # Use WhisperX instead of the regular Whisper
-        compute_type = "float16" if DEVICE == "cuda" else "float32"
-        models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
-        logger.info("WhisperX model loaded successfully")
-        socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'})
+        # Use regular Whisper instead of WhisperX to avoid compatibility issues
+        from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+        # Use a smaller model for faster processing
+        model_id = "openai/whisper-small"
+
+        models.asr_processor = WhisperProcessor.from_pretrained(model_id)
+        models.asr_model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
+
+        logger.info("Whisper ASR model loaded successfully")
+        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
     except Exception as e:
-        logger.error(f"Error loading WhisperX model: {str(e)}")
-        socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)})
+        logger.error(f"Error loading ASR model: {str(e)}")
+        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
 
     logger.info("Loading Llama 3.2 model...")
     try:
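The new load path pulls openai/whisper-small through transformers. A standalone smoke test of the same load, assuming torch and transformers are installed (the device selection mirrors the app's):

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "openai/whisper-small"

processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
model.eval()  # inference only
print(f"Loaded {model_id} on {DEVICE}")

Note the tradeoff: whisper-small is much lighter than WhisperX's large-v2 but less accurate, and plain Whisper drops WhisperX's word-level alignment and language detection.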
@@ -141,7 +143,8 @@ def health_check():
         "models_loaded": models.generator is not None and models.llm is not None
     })
 
 # Add a system status endpoint
+# Fix the system_status function:
 
 @app.route('/api/status')
 def system_status():
     return jsonify({
@@ -150,7 +153,7 @@ def system_status():
         "device": DEVICE,
         "models": {
             "generator": models.generator is not None,
-            "whisperx": models.whisperx_model is not None,
+            "asr": models.asr_model is not None,  # Use the correct model name
             "llm": models.llm is not None
         }
     })
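Clients can poll this endpoint to wait for readiness. A possible polling sketch, assuming the app serves on localhost:5000 (the port is an assumption, not stated in the diff) and the requests package is available:

import time
import requests

def wait_until_ready(url="http://localhost:5000/api/status", timeout=300):
    """Poll the status endpoint until the ASR and LLM models report loaded."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        status = requests.get(url).json()
        if status["models"]["asr"] and status["models"]["llm"]:
            return status
        time.sleep(2)
    raise TimeoutError("Models did not finish loading in time")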
@@ -260,8 +263,8 @@ def process_audio_queue(session_id, q):
         del user_queues[session_id]
 
 def process_audio_and_respond(session_id, data):
-    """Process audio data and generate a response using WhisperX"""
-    if models.generator is None or models.whisperx_model is None or models.llm is None:
+    """Process audio data and generate a response using standard Whisper"""
+    if models.generator is None or models.asr_model is None or models.llm is None:
         logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
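The guard now checks models.asr_model instead of models.whisperx_model before doing any work. For reference, a minimal sketch of rejecting one client from a worker thread, assuming Flask-SocketIO (the function name is illustrative):

from flask import Flask
from flask_socketio import SocketIO

app = Flask(__name__)
socketio = SocketIO(app)

def reject_if_loading(session_id, ready):
    # 'room=session_id' targets only the caller's socket; the app context
    # makes Flask globals usable from outside a request handler.
    if not ready:
        with app.app_context():
            socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
        return True
    return False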
@@ -293,47 +296,33 @@ def process_audio_and_respond(session_id, data):
         temp_path = temp_file.name
 
     try:
-        # Load audio using WhisperX
+        # Notify client that transcription is starting
         with app.app_context():
             socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
 
-        # Load audio with WhisperX instead of torchaudio
-        audio = whisperx.load_audio(temp_path)
+        # Load audio for ASR processing
+        import librosa
+        speech_array, sampling_rate = librosa.load(temp_path, sr=16000)
 
-        # Transcribe using WhisperX
-        batch_size = 16  # Adjust based on available memory
-        result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+        # Convert to required format
+        input_features = models.asr_processor(
+            speech_array,
+            sampling_rate=sampling_rate,
+            return_tensors="pt"
+        ).input_features.to(DEVICE)
 
-        # Get the detected language
-        language_code = result["language"]
-        logger.info(f"Detected language: {language_code}")
-
-        # Load alignment model if not already loaded
-        if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
-            # Clear previous models to save memory
-            if models.whisperx_align_model is not None:
-                del models.whisperx_align_model
-                del models.whisperx_align_metadata
-                gc.collect()
-                torch.cuda.empty_cache() if DEVICE == "cuda" else None
-
-            models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
-                language_code=language_code, device=DEVICE
-            )
-            models.last_language = language_code
-
-        # Align the transcript
-        result = whisperx.align(
-            result["segments"],
-            models.whisperx_align_model,
-            models.whisperx_align_metadata,
-            audio,
-            DEVICE,
-            return_char_alignments=False
+        # Generate token ids
+        predicted_ids = models.asr_model.generate(
+            input_features,
+            language="en",
+            task="transcribe"
         )
 
-        # Combine all segments into a single transcript
-        user_text = ' '.join([segment['text'] for segment in result['segments']])
+        # Decode the predicted ids to text
+        user_text = models.asr_processor.batch_decode(
+            predicted_ids,
+            skip_special_tokens=True
+        )[0]
 
         # If no text was recognized, don't process further
         if not user_text or len(user_text.strip()) == 0:
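End to end, the new path is: librosa load at 16 kHz, log-mel features via the processor, generate, then batch_decode. A self-contained sketch of the same flow, assuming a local sample.wav (the filename is illustrative):

import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(DEVICE)

# Whisper expects 16 kHz mono input; librosa resamples on load.
speech, sr = librosa.load("sample.wav", sr=16000)

input_features = processor(speech, sampling_rate=sr, return_tensors="pt").input_features.to(DEVICE)

with torch.no_grad():  # no gradients needed for inference
    predicted_ids = model.generate(input_features, language="en", task="transcribe")

print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])

One behavioral change worth noting: the language is pinned to "en" here, whereas the WhisperX path auto-detected it; multilingual input would need that parameter removed or set per request.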
@@ -369,8 +358,7 @@ def process_audio_and_respond(session_id, data):
         with app.app_context():
             socketio.emit('transcription', {
                 'text': user_text,
-                'speaker': speaker_id,
-                'segments': result['segments']  # Send detailed segments info
+                'speaker': speaker_id
             }, room=session_id)
 
         # Generate AI response using Llama
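Since batch_decode returns a single string, the detailed 'segments' payload is gone and clients receive only text and speaker. A hypothetical client-side handler reflecting the new shape, assuming python-socketio and the same localhost URL as above:

import socketio

sio = socketio.Client()

@sio.on('transcription')
def on_transcription(data):
    # 'segments' no longer exists on this event after this commit.
    print(f"[{data['speaker']}] {data['text']}")

sio.connect('http://localhost:5000')
sio.wait()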