Demo Update 18
@@ -8,6 +8,7 @@ import logging
 import threading
 import queue
 import tempfile
+import gc
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -18,6 +19,9 @@ from flask_socketio import SocketIO, emit
 from flask_cors import CORS
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
+# Import WhisperX for better transcription
+import whisperx
+
 from generator import load_csm_1b, Segment
 from dataclasses import dataclass
 
@@ -52,7 +56,10 @@ class AppModels:
     generator = None
     tokenizer = None
     llm = None
-    asr = None
+    whisperx_model = None
+    whisperx_align_model = None
+    whisperx_align_metadata = None
+    diarize_model = None
 
 models = AppModels()
 
@@ -69,22 +76,16 @@ def load_models():
         logger.error(f"Error loading CSM 1B model: {str(e)}")
         socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
 
-    logger.info("Loading ASR pipeline...")
+    logger.info("Loading WhisperX model...")
     try:
-        # Initialize the pipeline without the language parameter in the constructor
-        models.asr = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-small",
-            device=DEVICE
-        )
-
-        # Configure the model with the appropriate options
-        # Note that for whisper, language should be set during inference, not initialization
-        logger.info("ASR pipeline loaded successfully")
-        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
+        # Use WhisperX instead of the regular Whisper
+        compute_type = "float16" if DEVICE == "cuda" else "float32"
+        models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
+        logger.info("WhisperX model loaded successfully")
+        socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'})
     except Exception as e:
-        logger.error(f"Error loading ASR pipeline: {str(e)}")
-        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
+        logger.error(f"Error loading WhisperX model: {str(e)}")
+        socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)})
 
     logger.info("Loading Llama 3.2 model...")
     try:
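For reference (not part of this commit): a minimal standalone sketch of the WhisperX loading path used in load_models() above. It assumes whisperx is installed and that a local clip named sample.wav exists; the file name is a placeholder, while the model size, precision choice, and batch size mirror the diff.

# Standalone smoke test of the WhisperX loading path (illustrative sketch).
# "sample.wav" is a hypothetical file, not something referenced by this commit.
import torch
import whisperx

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if DEVICE == "cuda" else "float32"

model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)

audio = whisperx.load_audio("sample.wav")        # decodes to a 16 kHz mono float array
result = model.transcribe(audio, batch_size=16)  # same batch size the app uses
print(result["language"], [seg["text"] for seg in result["segments"]])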
@@ -149,7 +150,7 @@ def system_status():
         "device": DEVICE,
         "models": {
             "generator": models.generator is not None,
-            "asr": models.asr is not None,
+            "whisperx": models.whisperx_model is not None,
             "llm": models.llm is not None
         }
     })
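After the rename, the status dict keys on whisperx instead of asr. An illustrative shape of that dict (the boolean values below are placeholders, not taken from the diff):

# Example shape of the system_status() payload (illustrative values only)
{
    "device": "cuda",
    "models": {
        "generator": True,
        "whisperx": True,
        "llm": False
    }
}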
@@ -259,8 +260,8 @@ def process_audio_queue(session_id, q):
         del user_queues[session_id]
 
 def process_audio_and_respond(session_id, data):
-    """Process audio data and generate a response"""
-    if models.generator is None or models.asr is None or models.llm is None:
+    """Process audio data and generate a response using WhisperX"""
+    if models.generator is None or models.whisperx_model is None or models.llm is None:
         logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -292,7 +293,57 @@ def process_audio_and_respond(session_id, data):
         temp_path = temp_file.name
 
     try:
-        # Load audio file
+        # Load audio using WhisperX
+        with app.app_context():
+            socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
+
+        # Load audio with WhisperX instead of torchaudio
+        audio = whisperx.load_audio(temp_path)
+
+        # Transcribe using WhisperX
+        batch_size = 16 # Adjust based on available memory
+        result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+
+        # Get the detected language
+        language_code = result["language"]
+        logger.info(f"Detected language: {language_code}")
+
+        # Load alignment model if not already loaded
+        if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
+            # Clear previous models to save memory
+            if models.whisperx_align_model is not None:
+                del models.whisperx_align_model
+                del models.whisperx_align_metadata
+                gc.collect()
+                torch.cuda.empty_cache() if DEVICE == "cuda" else None
+
+            models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
+                language_code=language_code, device=DEVICE
+            )
+            models.last_language = language_code
+
+        # Align the transcript
+        result = whisperx.align(
+            result["segments"],
+            models.whisperx_align_model,
+            models.whisperx_align_metadata,
+            audio,
+            DEVICE,
+            return_char_alignments=False
+        )
+
+        # Combine all segments into a single transcript
+        user_text = ' '.join([segment['text'] for segment in result['segments']])
+
+        # If no text was recognized, don't process further
+        if not user_text or len(user_text.strip()) == 0:
+            with app.app_context():
+                socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
+            return
+
+        logger.info(f"Transcription: {user_text}")
+
+        # Load audio for CSM input
         waveform, sample_rate = torchaudio.load(temp_path)
 
         # Normalize to mono if needed
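The aligned result produced above is consumed twice later in this function: the joined text becomes user_text (added to the conversation history), and the raw segments ride along in the 'transcription' socket event added further down. A minimal sketch of that post-processing, assuming result is the dict returned by whisperx.align(), where each segment carries 'start' and 'end' in seconds plus 'text' (the helper name is hypothetical, not in the diff):

# Sketch (not part of the diff): how the aligned segments are used downstream.
def summarize_alignment(result):
    # Joined transcript, as built for user_text above
    user_text = ' '.join(segment['text'] for segment in result['segments'])
    # (start, end, text) triples: the fields the frontend's formatTime() renders
    spans = [(segment['start'], segment['end'], segment['text'])
             for segment in result['segments']]
    return user_text, spans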
@@ -307,24 +358,6 @@ def process_audio_and_respond(session_id, data):
                 new_freq=models.generator.sample_rate
             )
 
-        # Transcribe audio
-        with app.app_context():
-            socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
-
-        # Use the ASR pipeline to transcribe
-        transcription_result = models.asr(
-            {"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
-            return_timestamps=False,
-            generate_kwargs={"language": "en"} # Set language during inference
-        )
-        user_text = transcription_result['text'].strip()
-
-        # If no text was recognized, don't process further
-        if not user_text:
-            with app.app_context():
-                socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
-            return
-
         # Add the user's message to conversation history
         user_segment = conversation.add_segment(
             text=user_text,
@@ -336,7 +369,8 @@ def process_audio_and_respond(session_id, data):
         with app.app_context():
             socketio.emit('transcription', {
                 'text': user_text,
-                'speaker': speaker_id
+                'speaker': speaker_id,
+                'segments': result['segments'] # Send detailed segments info
             }, room=session_id)
 
         # Generate AI response using Llama
@@ -365,7 +399,7 @@ def process_audio_and_respond(session_id, data):
         with torch.no_grad():
             generated_ids = models.llm.generate(
                 input_ids,
-                attention_mask=attention_mask, # Add the attention mask
+                attention_mask=attention_mask,
                 max_new_tokens=100,
                 temperature=0.7,
                 top_p=0.9,
@@ -390,13 +424,13 @@ def process_audio_and_respond(session_id, data):
                 'chunk_index': 0
             }, room=session_id)
 
-        # Define AI speaker ID (use a consistent value for the AI's voice)
-        ai_speaker_id = 1 # Use speaker 1 for AI responses
+        # Define AI speaker ID
+        ai_speaker_id = conversation.ai_speaker_id
 
         # Generate audio
         audio_tensor = models.generator.generate(
             text=response_text,
-            speaker=ai_speaker_id, # Use the local variable instead of conversation.ai_speaker_id
+            speaker=ai_speaker_id,
             context=conversation.get_context(),
             max_audio_length_ms=10_000,
             temperature=0.9
@@ -405,7 +439,7 @@ def process_audio_and_respond(session_id, data):
         # Add AI response to conversation history
         ai_segment = conversation.add_segment(
             text=response_text,
-            speaker=ai_speaker_id, # Also use the local variable here
+            speaker=ai_speaker_id,
             audio=audio_tensor
         )
 
@@ -444,6 +478,8 @@ def process_audio_and_respond(session_id, data):
 
     except Exception as e:
         logger.error(f"Error processing audio: {str(e)}")
+        import traceback
+        logger.error(traceback.format_exc())
         with app.app_context():
             socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
     finally:
@@ -627,7 +627,59 @@ function addSystemMessage(text) {
 // Handle transcription response from server
 function handleTranscription(data) {
     const speaker = data.speaker === 0 ? 'user' : 'ai';
-    addMessage(data.text, speaker);
+
+    // Create the message div
+    const messageDiv = addMessage(data.text, speaker);
+
+    // If we have detailed segments from WhisperX, add timestamps
+    if (data.segments && data.segments.length > 0) {
+        // Add a timestamps container
+        const timestampsContainer = document.createElement('div');
+        timestampsContainer.className = 'timestamps-container';
+        timestampsContainer.style.display = 'none'; // Hidden by default
+
+        // Add a toggle button
+        const toggleButton = document.createElement('button');
+        toggleButton.className = 'timestamp-toggle';
+        toggleButton.textContent = 'Show Timestamps';
+        toggleButton.onclick = function() {
+            const isHidden = timestampsContainer.style.display === 'none';
+            timestampsContainer.style.display = isHidden ? 'block' : 'none';
+            toggleButton.textContent = isHidden ? 'Hide Timestamps' : 'Show Timestamps';
+        };
+
+        // Add timestamps for each segment
+        data.segments.forEach(segment => {
+            const timestampDiv = document.createElement('div');
+            timestampDiv.className = 'timestamp';
+
+            // Format start and end times
+            const startTime = formatTime(segment.start);
+            const endTime = formatTime(segment.end);
+
+            timestampDiv.innerHTML = `
+                <span class="time">[${startTime} - ${endTime}]</span>
+                <span class="text">${segment.text}</span>
+            `;
+
+            timestampsContainer.appendChild(timestampDiv);
+        });
+
+        // Add the timestamp elements to the message
+        messageDiv.appendChild(toggleButton);
+        messageDiv.appendChild(timestampsContainer);
+    }
+
+    return messageDiv;
+}
+
+// Helper function to format time in seconds to MM:SS.ms format
+function formatTime(seconds) {
+    const mins = Math.floor(seconds / 60);
+    const secs = Math.floor(seconds % 60);
+    const ms = Math.floor((seconds % 1) * 1000);
+
+    return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}.${ms.toString().padStart(3, '0')}`;
 }
 
 // Handle context update from server
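A quick worked check of the new formatTime() helper: formatTime(65.25) returns "01:05.250" — 65.25 s is one full minute, five whole seconds, and a 0.25 s remainder rendered as 250 ms, with each part zero-padded by padStart.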
@@ -804,7 +856,7 @@ function finalizeStreamingAudio() {
 
 // Add CSS styles for new UI elements
 document.addEventListener('DOMContentLoaded', function() {
-    // Add styles for processing state
+    // Add styles for processing state and timestamps
    const style = document.createElement('style');
    style.textContent = `
        .message.processing {
@@ -833,6 +885,44 @@ document.addEventListener('DOMContentLoaded', function() {
            0% { transform: rotate(0deg); }
            100% { transform: rotate(360deg); }
        }
+
+        /* Timestamp styles */
+        .timestamp-toggle {
+            font-size: 0.75em;
+            padding: 4px 8px;
+            margin-top: 8px;
+            background-color: #f0f0f0;
+            border: 1px solid #ddd;
+            border-radius: 4px;
+            cursor: pointer;
+        }
+
+        .timestamp-toggle:hover {
+            background-color: #e0e0e0;
+        }
+
+        .timestamps-container {
+            margin-top: 8px;
+            padding: 8px;
+            background-color: #f9f9f9;
+            border-radius: 4px;
+            font-size: 0.85em;
+        }
+
+        .timestamp {
+            margin-bottom: 4px;
+            padding: 2px 0;
+        }
+
+        .timestamp .time {
+            color: #666;
+            font-family: monospace;
+            margin-right: 8px;
+        }
+
+        .timestamp .text {
+            color: #333;
+        }
    `;
    document.head.appendChild(style);
 });