Demo Update 18

2025-03-30 02:40:08 -04:00
parent faacb612e7
commit 10902f1d71
2 changed files with 171 additions and 45 deletions

View File

@@ -8,6 +8,7 @@ import logging
import threading
import queue
import tempfile
import gc
from typing import Dict, List, Optional, Tuple
import torch
@@ -18,6 +19,9 @@ from flask_socketio import SocketIO, emit
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Import WhisperX for better transcription
import whisperx
from generator import load_csm_1b, Segment
from dataclasses import dataclass
@@ -52,7 +56,10 @@ class AppModels:
generator = None
tokenizer = None
llm = None
asr = None
whisperx_model = None
whisperx_align_model = None
whisperx_align_metadata = None
diarize_model = None
models = AppModels()
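
For orientation, the expanded model holder now tracks each WhisperX piece separately. A minimal sketch of the full class, with the last_language attribute (attached dynamically later in this commit) declared up front; the comments are this note's annotations, not the source's:

class AppModels:
    generator = None               # CSM 1B speech generator
    tokenizer = None               # Llama tokenizer
    llm = None                     # Llama 3.2 causal LM
    whisperx_model = None          # WhisperX ASR model
    whisperx_align_model = None    # per-language alignment model
    whisperx_align_metadata = None # metadata returned by whisperx.load_align_model
    diarize_model = None           # reserved for speaker diarization
    last_language = None           # language code of the cached alignment model

models = AppModels()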
@@ -69,22 +76,16 @@ def load_models():
logger.error(f"Error loading CSM 1B model: {str(e)}")
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
logger.info("Loading ASR pipeline...")
logger.info("Loading WhisperX model...")
try:
# Initialize the pipeline without the language parameter in the constructor
models.asr = pipeline(
"automatic-speech-recognition",
model="openai/whisper-small",
device=DEVICE
)
# Configure the model with the appropriate options
# Note that for whisper, language should be set during inference, not initialization
logger.info("ASR pipeline loaded successfully")
socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
# Use WhisperX instead of the regular Whisper
compute_type = "float16" if DEVICE == "cuda" else "float32"
models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
logger.info("WhisperX model loaded successfully")
socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'})
except Exception as e:
logger.error(f"Error loading ASR pipeline: {str(e)}")
socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
logger.error(f"Error loading WhisperX model: {str(e)}")
socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)})
logger.info("Loading Llama 3.2 model...")
try:
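
A minimal sketch of the WhisperX load with a graceful fallback, assuming faster-whisper's "int8" compute type as the degraded path; the fallback branch is an assumption of this note, not something this commit adds:

import torch
import whisperx

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_whisperx(model_name="large-v2"):
    # float16 only makes sense on CUDA; CPUs get float32
    compute_type = "float16" if DEVICE == "cuda" else "float32"
    try:
        return whisperx.load_model(model_name, DEVICE, compute_type=compute_type)
    except ValueError:
        # Assumption: the backend rejects float16 on unsupported hardware; retry quantized
        return whisperx.load_model(model_name, DEVICE, compute_type="int8")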
@@ -149,7 +150,7 @@ def system_status():
"device": DEVICE,
"models": {
"generator": models.generator is not None,
"asr": models.asr is not None,
"whisperx": models.whisperx_model is not None,
"llm": models.llm is not None
}
})
@@ -259,8 +260,8 @@ def process_audio_queue(session_id, q):
del user_queues[session_id]
def process_audio_and_respond(session_id, data):
"""Process audio data and generate a response"""
if models.generator is None or models.asr is None or models.llm is None:
"""Process audio data and generate a response using WhisperX"""
if models.generator is None or models.whisperx_model is None or models.llm is None:
logger.warning("Models not yet loaded!")
with app.app_context():
socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -292,7 +293,57 @@ def process_audio_and_respond(session_id, data):
temp_path = temp_file.name
try:
# Load audio file
# Load audio using WhisperX
with app.app_context():
socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
# Load audio with WhisperX instead of torchaudio
audio = whisperx.load_audio(temp_path)
# Transcribe using WhisperX
batch_size = 16 # Adjust based on available memory
result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
# Get the detected language
language_code = result["language"]
logger.info(f"Detected language: {language_code}")
# Load alignment model if not already loaded
if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
# Clear previous models to save memory
if models.whisperx_align_model is not None:
del models.whisperx_align_model
del models.whisperx_align_metadata
gc.collect()
if DEVICE == "cuda": torch.cuda.empty_cache()
models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
language_code=language_code, device=DEVICE
)
models.last_language = language_code
# Align the transcript
result = whisperx.align(
result["segments"],
models.whisperx_align_model,
models.whisperx_align_metadata,
audio,
DEVICE,
return_char_alignments=False
)
# Combine all segments into a single transcript
user_text = ' '.join(segment['text'].strip() for segment in result['segments'])
# If no text was recognized, don't process further
if not user_text or len(user_text.strip()) == 0:
with app.app_context():
socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
return
logger.info(f"Transcription: {user_text}")
# Load audio for CSM input
waveform, sample_rate = torchaudio.load(temp_path)
# Normalize to mono if needed
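
The transcribe call above hard-codes batch_size = 16. A hedged sketch of deriving it from free GPU memory instead; the thresholds are illustrative assumptions, not tuned values:

import torch

def pick_whisperx_batch_size(default=16):
    """Shrink the WhisperX batch size when little GPU memory is free."""
    if not torch.cuda.is_available():
        return 4  # CPU transcription: keep batches small
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    free_gb = free_bytes / 1024 ** 3
    if free_gb < 4:
        return 4
    if free_gb < 8:
        return 8
    return default

# e.g. result = models.whisperx_model.transcribe(audio, batch_size=pick_whisperx_batch_size())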
@@ -307,24 +358,6 @@ def process_audio_and_respond(session_id, data):
new_freq=models.generator.sample_rate
)
# Transcribe audio
with app.app_context():
socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
# Use the ASR pipeline to transcribe
transcription_result = models.asr(
{"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
return_timestamps=False,
generate_kwargs={"language": "en"} # Set language during inference
)
user_text = transcription_result['text'].strip()
# If no text was recognized, don't process further
if not user_text:
with app.app_context():
socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
return
# Add the user's message to conversation history
user_segment = conversation.add_segment(
text=user_text,
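
For context on the CSM input path above, a minimal sketch of the mono downmix and resample, assuming the generator exposes sample_rate as it does elsewhere in this file:

import torchaudio

def prepare_for_csm(path, target_sample_rate):
    waveform, sample_rate = torchaudio.load(path)
    # Average the channels down to mono if the recording is stereo
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Match the CSM generator's expected sample rate
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_sample_rate
        )
    return waveform.squeeze(0)

# e.g. audio_input = prepare_for_csm(temp_path, models.generator.sample_rate)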
@@ -336,7 +369,8 @@ def process_audio_and_respond(session_id, data):
with app.app_context():
socketio.emit('transcription', {
'text': user_text,
'speaker': speaker_id
'speaker': speaker_id,
'segments': result['segments'] # Send detailed segments info
}, room=session_id)
# Generate AI response using Llama
@@ -365,7 +399,7 @@ def process_audio_and_respond(session_id, data):
with torch.no_grad():
generated_ids = models.llm.generate(
input_ids,
attention_mask=attention_mask, # Add the attention mask
attention_mask=attention_mask,
max_new_tokens=100,
temperature=0.7,
top_p=0.9,
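
The attention_mask passed above is built outside this hunk. A minimal sketch of the tokenizer call that typically produces it; the prompt variable and the device move are assumptions about the surrounding code:

# For a single sequence the mask is all ones, but passing it explicitly
# silences the transformers warning about a missing attention mask.
inputs = models.tokenizer(prompt, return_tensors="pt").to(DEVICE)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]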
@@ -390,13 +424,13 @@ def process_audio_and_respond(session_id, data):
'chunk_index': 0
}, room=session_id)
# Define AI speaker ID (use a consistent value for the AI's voice)
ai_speaker_id = 1 # Use speaker 1 for AI responses
# Define AI speaker ID
ai_speaker_id = conversation.ai_speaker_id
# Generate audio
audio_tensor = models.generator.generate(
text=response_text,
speaker=ai_speaker_id, # Use the local variable instead of conversation.ai_speaker_id
speaker=ai_speaker_id,
context=conversation.get_context(),
max_audio_length_ms=10_000,
temperature=0.9
@@ -405,7 +439,7 @@ def process_audio_and_respond(session_id, data):
# Add AI response to conversation history
ai_segment = conversation.add_segment(
text=response_text,
speaker=ai_speaker_id, # Also use the local variable here
speaker=ai_speaker_id,
audio=audio_tensor
)
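
This commit does not change how the generated audio reaches the browser. Purely for reference, a hedged sketch of one way to package audio_tensor for a Socket.IO emit; the event name and payload fields are assumptions, not part of this codebase:

import base64
import io
import torchaudio

def tensor_to_wav_base64(audio_tensor, sample_rate):
    buffer = io.BytesIO()
    # torchaudio.save expects (channels, samples); CSM returns a 1-D tensor
    torchaudio.save(buffer, audio_tensor.unsqueeze(0).cpu(), sample_rate, format="wav")
    return base64.b64encode(buffer.getvalue()).decode("ascii")

# hypothetical usage:
# socketio.emit('audio_chunk', {
#     'audio': tensor_to_wav_base64(audio_tensor, models.generator.sample_rate),
#     'chunk_index': 0,
# }, room=session_id)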
@@ -444,6 +478,8 @@ def process_audio_and_respond(session_id, data):
except Exception as e:
logger.error(f"Error processing audio: {str(e)}")
import traceback
logger.error(traceback.format_exc())
with app.app_context():
socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
finally:

View File

@@ -627,7 +627,59 @@ function addSystemMessage(text) {
// Handle transcription response from server
function handleTranscription(data) {
const speaker = data.speaker === 0 ? 'user' : 'ai';
addMessage(data.text, speaker);
// Create the message div
const messageDiv = addMessage(data.text, speaker);
// If we have detailed segments from WhisperX, add timestamps
if (data.segments && data.segments.length > 0) {
// Add a timestamps container
const timestampsContainer = document.createElement('div');
timestampsContainer.className = 'timestamps-container';
timestampsContainer.style.display = 'none'; // Hidden by default
// Add a toggle button
const toggleButton = document.createElement('button');
toggleButton.className = 'timestamp-toggle';
toggleButton.textContent = 'Show Timestamps';
toggleButton.onclick = function() {
const isHidden = timestampsContainer.style.display === 'none';
timestampsContainer.style.display = isHidden ? 'block' : 'none';
toggleButton.textContent = isHidden ? 'Hide Timestamps' : 'Show Timestamps';
};
// Add timestamps for each segment
data.segments.forEach(segment => {
const timestampDiv = document.createElement('div');
timestampDiv.className = 'timestamp';
// Format start and end times
const startTime = formatTime(segment.start);
const endTime = formatTime(segment.end);
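// Note: segment.text is interpolated into innerHTML below; if transcripts can
// ever contain markup, escape it or build these spans with textContent instead.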
timestampDiv.innerHTML = `
<span class="time">[${startTime} - ${endTime}]</span>
<span class="text">${segment.text}</span>
`;
timestampsContainer.appendChild(timestampDiv);
});
// Add the timestamp elements to the message
messageDiv.appendChild(toggleButton);
messageDiv.appendChild(timestampsContainer);
}
return messageDiv;
}
// Helper function to format time in seconds to MM:SS.ms format
function formatTime(seconds) {
const mins = Math.floor(seconds / 60);
const secs = Math.floor(seconds % 60);
const ms = Math.floor((seconds % 1) * 1000);
return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}.${ms.toString().padStart(3, '0')}`;
}
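
For reference on the data.segments consumed above, the 'transcription' event built on the Python side carries per-segment start and end times in seconds plus the text. A sketch of the payload shape with illustrative values (WhisperX alignment also emits word-level fields, omitted here):

payload = {
    "text": "hello there how are you",
    "speaker": 0,
    "segments": [
        {"start": 0.42, "end": 1.35, "text": "hello there"},
        {"start": 1.50, "end": 2.80, "text": "how are you"},
    ],
}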
// Handle context update from server
@@ -804,7 +856,7 @@ function finalizeStreamingAudio() {
// Add CSS styles for new UI elements
document.addEventListener('DOMContentLoaded', function() {
// Add styles for processing state
// Add styles for processing state and timestamps
const style = document.createElement('style');
style.textContent = `
.message.processing {
@@ -833,6 +885,44 @@ document.addEventListener('DOMContentLoaded', function() {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
/* Timestamp styles */
.timestamp-toggle {
font-size: 0.75em;
padding: 4px 8px;
margin-top: 8px;
background-color: #f0f0f0;
border: 1px solid #ddd;
border-radius: 4px;
cursor: pointer;
}
.timestamp-toggle:hover {
background-color: #e0e0e0;
}
.timestamps-container {
margin-top: 8px;
padding: 8px;
background-color: #f9f9f9;
border-radius: 4px;
font-size: 0.85em;
}
.timestamp {
margin-bottom: 4px;
padding: 2px 0;
}
.timestamp .time {
color: #666;
font-family: monospace;
margin-right: 8px;
}
.timestamp .text {
color: #333;
}
`;
document.head.appendChild(style);
});