Demo Update 18

2025-03-30 02:40:08 -04:00
parent faacb612e7
commit 10902f1d71
2 changed files with 171 additions and 45 deletions

View File

@@ -8,6 +8,7 @@ import logging
 import threading
 import queue
 import tempfile
+import gc
 from typing import Dict, List, Optional, Tuple
 import torch
@@ -18,6 +19,9 @@ from flask_socketio import SocketIO, emit
 from flask_cors import CORS
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+# Import WhisperX for better transcription
+import whisperx
 from generator import load_csm_1b, Segment
 from dataclasses import dataclass
@@ -52,7 +56,10 @@ class AppModels:
     generator = None
     tokenizer = None
     llm = None
-    asr = None
+    whisperx_model = None
+    whisperx_align_model = None
+    whisperx_align_metadata = None
+    diarize_model = None

 models = AppModels()
@@ -69,22 +76,16 @@ def load_models():
logger.error(f"Error loading CSM 1B model: {str(e)}") logger.error(f"Error loading CSM 1B model: {str(e)}")
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
logger.info("Loading ASR pipeline...") logger.info("Loading WhisperX model...")
try: try:
# Initialize the pipeline without the language parameter in the constructor # Use WhisperX instead of the regular Whisper
models.asr = pipeline( compute_type = "float16" if DEVICE == "cuda" else "float32"
"automatic-speech-recognition", models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
model="openai/whisper-small", logger.info("WhisperX model loaded successfully")
device=DEVICE socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'})
)
# Configure the model with the appropriate options
# Note that for whisper, language should be set during inference, not initialization
logger.info("ASR pipeline loaded successfully")
socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
except Exception as e: except Exception as e:
logger.error(f"Error loading ASR pipeline: {str(e)}") logger.error(f"Error loading WhisperX model: {str(e)}")
socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)})
logger.info("Loading Llama 3.2 model...") logger.info("Loading Llama 3.2 model...")
try: try:
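For reference, the loading pattern this hunk switches to reduces to the minimal sketch below. The "large-v2" checkpoint and the float16/float32 switch come straight from the diff; the device selection and module-level layout are assumed boilerplate, not part of the commit.

import torch
import whisperx

# Pick a device and a matching compute type; float16 is only worthwhile on GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if DEVICE == "cuda" else "float32"

# Load the WhisperX ASR model once at startup, as load_models() does above.
whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)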
@@ -149,7 +150,7 @@ def system_status():
"device": DEVICE, "device": DEVICE,
"models": { "models": {
"generator": models.generator is not None, "generator": models.generator is not None,
"asr": models.asr is not None, "whisperx": models.whisperx_model is not None,
"llm": models.llm is not None "llm": models.llm is not None
} }
}) })
@@ -259,8 +260,8 @@ def process_audio_queue(session_id, q):
         del user_queues[session_id]

 def process_audio_and_respond(session_id, data):
-    """Process audio data and generate a response"""
-    if models.generator is None or models.asr is None or models.llm is None:
+    """Process audio data and generate a response using WhisperX"""
+    if models.generator is None or models.whisperx_model is None or models.llm is None:
         logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -292,7 +293,57 @@ def process_audio_and_respond(session_id, data):
         temp_path = temp_file.name

         try:
-            # Load audio file
+            # Load audio using WhisperX
+            with app.app_context():
+                socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
+
+            # Load audio with WhisperX instead of torchaudio
+            audio = whisperx.load_audio(temp_path)
+
+            # Transcribe using WhisperX
+            batch_size = 16  # Adjust based on available memory
+            result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+
+            # Get the detected language
+            language_code = result["language"]
+            logger.info(f"Detected language: {language_code}")
+
+            # Load alignment model if not already loaded
+            if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
+                # Clear previous models to save memory
+                if models.whisperx_align_model is not None:
+                    del models.whisperx_align_model
+                    del models.whisperx_align_metadata
+                    gc.collect()
+                    torch.cuda.empty_cache() if DEVICE == "cuda" else None
+
+                models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
+                    language_code=language_code, device=DEVICE
+                )
+                models.last_language = language_code
+
+            # Align the transcript
+            result = whisperx.align(
+                result["segments"],
+                models.whisperx_align_model,
+                models.whisperx_align_metadata,
+                audio,
+                DEVICE,
+                return_char_alignments=False
+            )
+
+            # Combine all segments into a single transcript
+            user_text = ' '.join([segment['text'] for segment in result['segments']])
+
+            # If no text was recognized, don't process further
+            if not user_text or len(user_text.strip()) == 0:
+                with app.app_context():
+                    socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
+                return
+
+            logger.info(f"Transcription: {user_text}")
+
+            # Load audio for CSM input
             waveform, sample_rate = torchaudio.load(temp_path)

             # Normalize to mono if needed
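Taken out of the Flask handler, the transcribe-then-align flow added above amounts to roughly the following sketch. The file path and batch size are illustrative; the WhisperX calls mirror the ones in the hunk.

import gc
import torch
import whisperx

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if DEVICE == "cuda" else "float32"

model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
audio = whisperx.load_audio("recording.wav")       # decoded to 16 kHz mono float32
result = model.transcribe(audio, batch_size=16)    # coarse segments plus detected language

# The alignment model is language-specific, so it is (re)loaded per detected language.
align_model, metadata = whisperx.load_align_model(
    language_code=result["language"], device=DEVICE
)
result = whisperx.align(
    result["segments"], align_model, metadata, audio, DEVICE,
    return_char_alignments=False
)
transcript = " ".join(seg["text"].strip() for seg in result["segments"])

# Free the alignment model before switching languages, as the handler does.
del align_model, metadata
gc.collect()
if DEVICE == "cuda":
    torch.cuda.empty_cache()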
@@ -307,24 +358,6 @@ def process_audio_and_respond(session_id, data):
                 new_freq=models.generator.sample_rate
             )

-            # Transcribe audio
-            with app.app_context():
-                socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
-
-            # Use the ASR pipeline to transcribe
-            transcription_result = models.asr(
-                {"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
-                return_timestamps=False,
-                generate_kwargs={"language": "en"}  # Set language during inference
-            )
-
-            user_text = transcription_result['text'].strip()
-
-            # If no text was recognized, don't process further
-            if not user_text:
-                with app.app_context():
-                    socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
-                return
-
             # Add the user's message to conversation history
             user_segment = conversation.add_segment(
                 text=user_text,
@@ -336,7 +369,8 @@ def process_audio_and_respond(session_id, data):
             with app.app_context():
                 socketio.emit('transcription', {
                     'text': user_text,
-                    'speaker': speaker_id
+                    'speaker': speaker_id,
+                    'segments': result['segments']  # Send detailed segments info
                 }, room=session_id)

             # Generate AI response using Llama
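The 'segments' field forwarded here is the aligned segment list from whisperx.align, which the frontend below turns into a timestamp view. A hedged sketch of the payload shape the client code relies on (field names match what handleTranscription reads; the values are illustrative):

transcription_payload = {
    'text': user_text,
    'speaker': speaker_id,
    'segments': [
        # Each aligned segment carries start/end times in seconds plus its text;
        # only these three keys are consumed by the frontend code in this commit.
        {'start': 0.42, 'end': 1.87, 'text': ' Hello there'},
        {'start': 2.03, 'end': 3.10, 'text': ' how are you?'},
    ],
}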
@@ -365,7 +399,7 @@ def process_audio_and_respond(session_id, data):
             with torch.no_grad():
                 generated_ids = models.llm.generate(
                     input_ids,
-                    attention_mask=attention_mask,  # Add the attention mask
+                    attention_mask=attention_mask,
                     max_new_tokens=100,
                     temperature=0.7,
                     top_p=0.9,
@@ -390,13 +424,13 @@ def process_audio_and_respond(session_id, data):
                     'chunk_index': 0
                 }, room=session_id)

-            # Define AI speaker ID (use a consistent value for the AI's voice)
-            ai_speaker_id = 1  # Use speaker 1 for AI responses
+            # Define AI speaker ID
+            ai_speaker_id = conversation.ai_speaker_id

             # Generate audio
             audio_tensor = models.generator.generate(
                 text=response_text,
-                speaker=ai_speaker_id,  # Use the local variable instead of conversation.ai_speaker_id
+                speaker=ai_speaker_id,
                 context=conversation.get_context(),
                 max_audio_length_ms=10_000,
                 temperature=0.9
@@ -405,7 +439,7 @@ def process_audio_and_respond(session_id, data):
             # Add AI response to conversation history
             ai_segment = conversation.add_segment(
                 text=response_text,
-                speaker=ai_speaker_id,  # Also use the local variable here
+                speaker=ai_speaker_id,
                 audio=audio_tensor
             )
@@ -444,6 +478,8 @@ def process_audio_and_respond(session_id, data):
         except Exception as e:
             logger.error(f"Error processing audio: {str(e)}")
+            import traceback
+            logger.error(traceback.format_exc())
             with app.app_context():
                 socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
         finally:

View File

@@ -627,7 +627,59 @@ function addSystemMessage(text) {
 // Handle transcription response from server
 function handleTranscription(data) {
     const speaker = data.speaker === 0 ? 'user' : 'ai';
-    addMessage(data.text, speaker);
+
+    // Create the message div
+    const messageDiv = addMessage(data.text, speaker);
+
+    // If we have detailed segments from WhisperX, add timestamps
+    if (data.segments && data.segments.length > 0) {
+        // Add a timestamps container
+        const timestampsContainer = document.createElement('div');
+        timestampsContainer.className = 'timestamps-container';
+        timestampsContainer.style.display = 'none'; // Hidden by default
+
+        // Add a toggle button
+        const toggleButton = document.createElement('button');
+        toggleButton.className = 'timestamp-toggle';
+        toggleButton.textContent = 'Show Timestamps';
+        toggleButton.onclick = function() {
+            const isHidden = timestampsContainer.style.display === 'none';
+            timestampsContainer.style.display = isHidden ? 'block' : 'none';
+            toggleButton.textContent = isHidden ? 'Hide Timestamps' : 'Show Timestamps';
+        };
+
+        // Add timestamps for each segment
+        data.segments.forEach(segment => {
+            const timestampDiv = document.createElement('div');
+            timestampDiv.className = 'timestamp';
+
+            // Format start and end times
+            const startTime = formatTime(segment.start);
+            const endTime = formatTime(segment.end);
+
+            timestampDiv.innerHTML = `
+                <span class="time">[${startTime} - ${endTime}]</span>
+                <span class="text">${segment.text}</span>
+            `;
+
+            timestampsContainer.appendChild(timestampDiv);
+        });
+
+        // Add the timestamp elements to the message
+        messageDiv.appendChild(toggleButton);
+        messageDiv.appendChild(timestampsContainer);
+    }
+
+    return messageDiv;
+}
+
+// Helper function to format time in seconds to MM:SS.ms format
+function formatTime(seconds) {
+    const mins = Math.floor(seconds / 60);
+    const secs = Math.floor(seconds % 60);
+    const ms = Math.floor((seconds % 1) * 1000);
+    return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}.${ms.toString().padStart(3, '0')}`;
 }

 // Handle context update from server
@@ -804,7 +856,7 @@ function finalizeStreamingAudio() {
 // Add CSS styles for new UI elements
 document.addEventListener('DOMContentLoaded', function() {
-    // Add styles for processing state
+    // Add styles for processing state and timestamps
     const style = document.createElement('style');
     style.textContent = `
         .message.processing {
@@ -833,6 +885,44 @@ document.addEventListener('DOMContentLoaded', function() {
             0% { transform: rotate(0deg); }
             100% { transform: rotate(360deg); }
         }
+
+        /* Timestamp styles */
+        .timestamp-toggle {
+            font-size: 0.75em;
+            padding: 4px 8px;
+            margin-top: 8px;
+            background-color: #f0f0f0;
+            border: 1px solid #ddd;
+            border-radius: 4px;
+            cursor: pointer;
+        }
+
+        .timestamp-toggle:hover {
+            background-color: #e0e0e0;
+        }
+
+        .timestamps-container {
+            margin-top: 8px;
+            padding: 8px;
+            background-color: #f9f9f9;
+            border-radius: 4px;
+            font-size: 0.85em;
+        }
+
+        .timestamp {
+            margin-bottom: 4px;
+            padding: 2px 0;
+        }
+
+        .timestamp .time {
+            color: #666;
+            font-family: monospace;
+            margin-right: 8px;
+        }
+
+        .timestamp .text {
+            color: #333;
+        }
     `;
     document.head.appendChild(style);
 });