Demo Update 18
@@ -8,6 +8,7 @@ import logging
 import threading
 import queue
 import tempfile
+import gc
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -18,6 +19,9 @@ from flask_socketio import SocketIO, emit
 from flask_cors import CORS
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

+# Import WhisperX for better transcription
+import whisperx
+
 from generator import load_csm_1b, Segment
 from dataclasses import dataclass

@@ -52,7 +56,10 @@ class AppModels:
     generator = None
     tokenizer = None
     llm = None
-    asr = None
+    whisperx_model = None
+    whisperx_align_model = None
+    whisperx_align_metadata = None
+    diarize_model = None

 models = AppModels()

@@ -69,22 +76,16 @@ def load_models():
         logger.error(f"Error loading CSM 1B model: {str(e)}")
         socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})

-    logger.info("Loading ASR pipeline...")
+    logger.info("Loading WhisperX model...")
     try:
-        # Initialize the pipeline without the language parameter in the constructor
-        models.asr = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-small",
-            device=DEVICE
-        )
-
-        # Configure the model with the appropriate options
-        # Note that for whisper, language should be set during inference, not initialization
-        logger.info("ASR pipeline loaded successfully")
-        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
+        # Use WhisperX instead of the regular Whisper
+        compute_type = "float16" if DEVICE == "cuda" else "float32"
+        models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
+        logger.info("WhisperX model loaded successfully")
+        socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'})
     except Exception as e:
-        logger.error(f"Error loading ASR pipeline: {str(e)}")
-        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
+        logger.error(f"Error loading WhisperX model: {str(e)}")
+        socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)})

     logger.info("Loading Llama 3.2 model...")
     try:
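Note: the WhisperX load path above can be smoke-tested on its own. A minimal sketch, using only calls that appear in this commit (the audio path is illustrative, and it assumes the whisperx package is installed):

    import whisperx

    DEVICE = "cuda"  # or "cpu"
    # fp16 only pays off on GPU; fall back to fp32 on CPU
    compute_type = "float16" if DEVICE == "cuda" else "float32"

    # Load once at startup; "large-v2" downloads several GB on first use
    model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)

    audio = whisperx.load_audio("sample.wav")  # illustrative path
    result = model.transcribe(audio, batch_size=16)
    print(result["language"], result["segments"][0]["text"])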
@@ -149,7 +150,7 @@ def system_status():
         "device": DEVICE,
         "models": {
             "generator": models.generator is not None,
-            "asr": models.asr is not None,
+            "whisperx": models.whisperx_model is not None,
             "llm": models.llm is not None
         }
     })
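With this change the status payload reports whisperx instead of asr. Illustratively (field values made up), a fully loaded server would return:

    {
      "device": "cuda",
      "models": {
        "generator": true,
        "whisperx": true,
        "llm": true
      }
    }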
@@ -259,8 +260,8 @@ def process_audio_queue(session_id, q):
         del user_queues[session_id]

 def process_audio_and_respond(session_id, data):
-    """Process audio data and generate a response"""
-    if models.generator is None or models.asr is None or models.llm is None:
+    """Process audio data and generate a response using WhisperX"""
+    if models.generator is None or models.whisperx_model is None or models.llm is None:
         logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -292,7 +293,57 @@ def process_audio_and_respond(session_id, data):
         temp_path = temp_file.name

     try:
-        # Load audio file
+        # Load audio using WhisperX
+        with app.app_context():
+            socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
+
+        # Load audio with WhisperX instead of torchaudio
+        audio = whisperx.load_audio(temp_path)
+
+        # Transcribe using WhisperX
+        batch_size = 16  # Adjust based on available memory
+        result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+
+        # Get the detected language
+        language_code = result["language"]
+        logger.info(f"Detected language: {language_code}")
+
+        # Load alignment model if not already loaded
+        if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
+            # Clear previous models to save memory
+            if models.whisperx_align_model is not None:
+                del models.whisperx_align_model
+                del models.whisperx_align_metadata
+                gc.collect()
+                torch.cuda.empty_cache() if DEVICE == "cuda" else None
+
+            models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
+                language_code=language_code, device=DEVICE
+            )
+            models.last_language = language_code
+
+        # Align the transcript
+        result = whisperx.align(
+            result["segments"],
+            models.whisperx_align_model,
+            models.whisperx_align_metadata,
+            audio,
+            DEVICE,
+            return_char_alignments=False
+        )
+
+        # Combine all segments into a single transcript
+        user_text = ' '.join([segment['text'] for segment in result['segments']])
+
+        # If no text was recognized, don't process further
+        if not user_text or len(user_text.strip()) == 0:
+            with app.app_context():
+                socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
+            return
+
+        logger.info(f"Transcription: {user_text}")
+
+        # Load audio for CSM input
         waveform, sample_rate = torchaudio.load(temp_path)

         # Normalize to mono if needed
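For orientation, the aligned result consumed above is a dict of segments with start/end times in seconds; this is also what the client later receives as 'segments'. A runnable sketch with invented values:

    # Illustrative shape of whisperx.align output (times in seconds; values invented)
    result = {
        "segments": [
            {"start": 0.52, "end": 2.31, "text": "Hello there."},
            {"start": 2.80, "end": 4.10, "text": "How are you?"},
        ]
    }
    user_text = ' '.join(segment['text'] for segment in result['segments'])
    print(user_text)  # -> Hello there. How are you?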
@@ -307,24 +358,6 @@ def process_audio_and_respond(session_id, data):
                 new_freq=models.generator.sample_rate
             )

-        # Transcribe audio
-        with app.app_context():
-            socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
-
-        # Use the ASR pipeline to transcribe
-        transcription_result = models.asr(
-            {"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
-            return_timestamps=False,
-            generate_kwargs={"language": "en"}  # Set language during inference
-        )
-        user_text = transcription_result['text'].strip()
-
-        # If no text was recognized, don't process further
-        if not user_text:
-            with app.app_context():
-                socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
-            return
-
         # Add the user's message to conversation history
         user_segment = conversation.add_segment(
             text=user_text,
@@ -336,7 +369,8 @@ def process_audio_and_respond(session_id, data):
         with app.app_context():
             socketio.emit('transcription', {
                 'text': user_text,
-                'speaker': speaker_id
+                'speaker': speaker_id,
+                'segments': result['segments']  # Send detailed segments info
             }, room=session_id)

         # Generate AI response using Llama
@@ -365,7 +399,7 @@ def process_audio_and_respond(session_id, data):
         with torch.no_grad():
             generated_ids = models.llm.generate(
                 input_ids,
-                attention_mask=attention_mask,  # Add the attention mask
+                attention_mask=attention_mask,
                 max_new_tokens=100,
                 temperature=0.7,
                 top_p=0.9,
@@ -390,13 +424,13 @@ def process_audio_and_respond(session_id, data):
                 'chunk_index': 0
             }, room=session_id)

-        # Define AI speaker ID (use a consistent value for the AI's voice)
-        ai_speaker_id = 1  # Use speaker 1 for AI responses
+        # Define AI speaker ID
+        ai_speaker_id = conversation.ai_speaker_id

         # Generate audio
         audio_tensor = models.generator.generate(
             text=response_text,
-            speaker=ai_speaker_id,  # Use the local variable instead of conversation.ai_speaker_id
+            speaker=ai_speaker_id,
             context=conversation.get_context(),
             max_audio_length_ms=10_000,
             temperature=0.9
@@ -405,7 +439,7 @@ def process_audio_and_respond(session_id, data):
         # Add AI response to conversation history
         ai_segment = conversation.add_segment(
             text=response_text,
-            speaker=ai_speaker_id,  # Also use the local variable here
+            speaker=ai_speaker_id,
             audio=audio_tensor
         )

@@ -444,6 +478,8 @@ def process_audio_and_respond(session_id, data):

     except Exception as e:
         logger.error(f"Error processing audio: {str(e)}")
+        import traceback
+        logger.error(traceback.format_exc())
         with app.app_context():
             socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
     finally:
@@ -627,7 +627,59 @@ function addSystemMessage(text) {
 // Handle transcription response from server
 function handleTranscription(data) {
     const speaker = data.speaker === 0 ? 'user' : 'ai';
-    addMessage(data.text, speaker);
+
+    // Create the message div
+    const messageDiv = addMessage(data.text, speaker);
+
+    // If we have detailed segments from WhisperX, add timestamps
+    if (data.segments && data.segments.length > 0) {
+        // Add a timestamps container
+        const timestampsContainer = document.createElement('div');
+        timestampsContainer.className = 'timestamps-container';
+        timestampsContainer.style.display = 'none'; // Hidden by default
+
+        // Add a toggle button
+        const toggleButton = document.createElement('button');
+        toggleButton.className = 'timestamp-toggle';
+        toggleButton.textContent = 'Show Timestamps';
+        toggleButton.onclick = function() {
+            const isHidden = timestampsContainer.style.display === 'none';
+            timestampsContainer.style.display = isHidden ? 'block' : 'none';
+            toggleButton.textContent = isHidden ? 'Hide Timestamps' : 'Show Timestamps';
+        };
+
+        // Add timestamps for each segment
+        data.segments.forEach(segment => {
+            const timestampDiv = document.createElement('div');
+            timestampDiv.className = 'timestamp';
+
+            // Format start and end times
+            const startTime = formatTime(segment.start);
+            const endTime = formatTime(segment.end);
+
+            timestampDiv.innerHTML = `
+                <span class="time">[${startTime} - ${endTime}]</span>
+                <span class="text">${segment.text}</span>
+            `;
+
+            timestampsContainer.appendChild(timestampDiv);
+        });
+
+        // Add the timestamp elements to the message
+        messageDiv.appendChild(toggleButton);
+        messageDiv.appendChild(timestampsContainer);
+    }
+
+    return messageDiv;
 }
+
+// Helper function to format time in seconds to MM:SS.ms format
+function formatTime(seconds) {
+    const mins = Math.floor(seconds / 60);
+    const secs = Math.floor(seconds % 60);
+    const ms = Math.floor((seconds % 1) * 1000);
+
+    return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}.${ms.toString().padStart(3, '0')}`;
+}

 // Handle context update from server
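As a worked example of the new helper: formatTime(83.456) splits 83.456 s into 1 min, 23 s, 456 ms and, after zero-padding, renders "01:23.456"; formatTime(5) renders "00:05.000".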
@@ -804,7 +856,7 @@ function finalizeStreamingAudio() {

 // Add CSS styles for new UI elements
 document.addEventListener('DOMContentLoaded', function() {
-    // Add styles for processing state
+    // Add styles for processing state and timestamps
     const style = document.createElement('style');
     style.textContent = `
         .message.processing {
@@ -833,6 +885,44 @@ document.addEventListener('DOMContentLoaded', function() {
             0% { transform: rotate(0deg); }
             100% { transform: rotate(360deg); }
         }
+
+        /* Timestamp styles */
+        .timestamp-toggle {
+            font-size: 0.75em;
+            padding: 4px 8px;
+            margin-top: 8px;
+            background-color: #f0f0f0;
+            border: 1px solid #ddd;
+            border-radius: 4px;
+            cursor: pointer;
+        }
+
+        .timestamp-toggle:hover {
+            background-color: #e0e0e0;
+        }
+
+        .timestamps-container {
+            margin-top: 8px;
+            padding: 8px;
+            background-color: #f9f9f9;
+            border-radius: 4px;
+            font-size: 0.85em;
+        }
+
+        .timestamp {
+            margin-bottom: 4px;
+            padding: 2px 0;
+        }
+
+        .timestamp .time {
+            color: #666;
+            font-family: monospace;
+            margin-right: 8px;
+        }
+
+        .timestamp .text {
+            color: #333;
+        }
     `;
     document.head.appendChild(style);
 });