Demo Update 20

2025-03-30 02:53:30 -04:00
parent 10902f1d71
commit 4fb2c9bc52
4 changed files with 196 additions and 55 deletions

View File (HTML template: embedded CSS and page markup)

@@ -260,6 +260,47 @@
     font-size: 0.8em;
     color: #777;
 }
+
+/* Model status indicators */
+.model-status {
+    display: flex;
+    gap: 8px;
+}
+
+.model-indicator {
+    padding: 3px 6px;
+    border-radius: 4px;
+    font-size: 0.7em;
+    font-weight: bold;
+}
+
+.model-indicator.loading {
+    background-color: #ffd54f;
+    color: #000;
+}
+
+.model-indicator.loaded {
+    background-color: #4CAF50;
+    color: white;
+}
+
+.model-indicator.error {
+    background-color: #f44336;
+    color: white;
+}
+
+.message-timestamp {
+    font-size: 0.7em;
+    color: #888;
+    margin-top: 4px;
+    text-align: right;
+}
+
+.simple-timestamp {
+    font-size: 0.8em;
+    color: #888;
+    margin-top: 5px;
+}
 </style>
 </head>
 <body>
@@ -276,6 +317,13 @@
<div id="statusDot" class="status-dot"></div> <div id="statusDot" class="status-dot"></div>
<span id="statusText">Disconnected</span> <span id="statusText">Disconnected</span>
</div> </div>
<!-- Add this model status panel -->
<div class="model-status">
<div id="csmStatus" class="model-indicator loading" title="Loading CSM model...">CSM</div>
<div id="asrStatus" class="model-indicator loading" title="Loading ASR model...">ASR</div>
<div id="llmStatus" class="model-indicator loading" title="Loading LLM model...">LLM</div>
</div>
</div> </div>
<div id="conversation" class="conversation"></div> <div id="conversation" class="conversation"></div>
</div> </div>

View File (Python requirements file)

@@ -1,7 +1,11 @@
+flask==2.2.5
+flask-socketio==5.3.6
+flask-cors==4.0.0
 torch==2.4.0
 torchaudio==2.4.0
 tokenizers==0.21.0
 transformers==4.49.0
+librosa==0.10.1
 huggingface_hub==0.28.1
 moshi==0.2.2
 torchtune==0.4.0
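
The new pins add the Flask/Socket.IO serving stack (flask, flask-socketio, flask-cors) and librosa for audio decoding. A minimal sanity-check sketch after running pip install -r requirements.txt, using only the standard library, compares installed versions against these pins:

# sanity-check sketch: confirm the pinned packages resolved as expected
# (the package names below mirror the requirements lines above)
from importlib.metadata import version, PackageNotFoundError

pins = {
    "flask": "2.2.5",
    "flask-socketio": "5.3.6",
    "flask-cors": "4.0.0",
    "librosa": "0.10.1",
    "transformers": "4.49.0",
}

for name, expected in pins.items():
    try:
        installed = version(name)
        print(f"{name}: installed {installed}, pinned {expected}")
    except PackageNotFoundError:
        print(f"{name}: not installed")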

View File (Flask-SocketIO backend, Python)

@@ -56,12 +56,8 @@ class AppModels:
     generator = None
     tokenizer = None
     llm = None
-    whisperx_model = None
-    whisperx_align_model = None
-    whisperx_align_metadata = None
-    diarize_model = None
+    asr_model = None
+    asr_processor = None

 models = AppModels()

 def load_models():
     """Load all required models"""
@@ -76,16 +72,22 @@ def load_models():
logger.error(f"Error loading CSM 1B model: {str(e)}") logger.error(f"Error loading CSM 1B model: {str(e)}")
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
logger.info("Loading WhisperX model...") logger.info("Loading Whisper ASR model...")
try: try:
# Use WhisperX instead of the regular Whisper # Use regular Whisper instead of WhisperX to avoid compatibility issues
compute_type = "float16" if DEVICE == "cuda" else "float32" from transformers import WhisperProcessor, WhisperForConditionalGeneration
models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
logger.info("WhisperX model loaded successfully") # Use a smaller model for faster processing
socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'}) model_id = "openai/whisper-small"
models.asr_processor = WhisperProcessor.from_pretrained(model_id)
models.asr_model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
logger.info("Whisper ASR model loaded successfully")
socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
except Exception as e: except Exception as e:
logger.error(f"Error loading WhisperX model: {str(e)}") logger.error(f"Error loading ASR model: {str(e)}")
socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)}) socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
logger.info("Loading Llama 3.2 model...") logger.info("Loading Llama 3.2 model...")
try: try:
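
During loading, the server pushes model_status events over Socket.IO with the payload shape shown above: {'model': 'csm'|'asr'|'llm', 'status': 'loaded'|'error', 'message': <error text>}. The browser UI consumes these (see the JavaScript file below); as a rough server-side check, a minimal listener sketch is possible with the python-socketio client package, which is an assumption here since it is not pinned in requirements.txt, as is the http://localhost:5000 dev URL:

# minimal model_status listener sketch (assumes python-socketio client is installed
# and the app runs on http://localhost:5000; neither is specified in this diff)
import socketio

sio = socketio.Client()

@sio.on('model_status')
def on_model_status(data):
    # payload mirrors the server emits above
    line = f"{data['model'].upper()}: {data['status']}"
    if data.get('message'):
        line += f" ({data['message']})"
    print(line)

sio.connect('http://localhost:5000')
sio.wait()  # keep listening; Ctrl+C to stop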
@@ -141,7 +143,8 @@ def health_check():
"models_loaded": models.generator is not None and models.llm is not None "models_loaded": models.generator is not None and models.llm is not None
}) })
# Add a system status endpoint # Fix the system_status function:
@app.route('/api/status') @app.route('/api/status')
def system_status(): def system_status():
return jsonify({ return jsonify({
@@ -150,7 +153,7 @@ def system_status():
"device": DEVICE, "device": DEVICE,
"models": { "models": {
"generator": models.generator is not None, "generator": models.generator is not None,
"whisperx": models.whisperx_model is not None, "asr": models.asr_model is not None, # Use the correct model name
"llm": models.llm is not None "llm": models.llm is not None
} }
}) })
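
With the key renamed from whisperx to asr, /api/status now reports readiness for generator, asr, and llm, and the updated client JavaScript (below) polls it before starting a recording. A minimal readiness-poll sketch, assuming the requests package (not pinned in requirements.txt) and a http://localhost:5000 dev URL, looks like:

# readiness-poll sketch against the updated /api/status endpoint
import time
import requests

def wait_for_models(url="http://localhost:5000/api/status", timeout=300):
    deadline = time.time() + timeout
    while time.time() < deadline:
        status = requests.get(url, timeout=5).json()
        models = status.get("models", {})
        # same keys the endpoint returns: generator, asr, llm
        if models.get("generator") and models.get("asr") and models.get("llm"):
            return True
        time.sleep(2)
    return False

if __name__ == "__main__":
    print("models ready" if wait_for_models() else "timed out waiting for models")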
@@ -260,8 +263,8 @@ def process_audio_queue(session_id, q):
         del user_queues[session_id]

 def process_audio_and_respond(session_id, data):
-    """Process audio data and generate a response using WhisperX"""
-    if models.generator is None or models.whisperx_model is None or models.llm is None:
+    """Process audio data and generate a response using standard Whisper"""
+    if models.generator is None or models.asr_model is None or models.llm is None:
         logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -293,47 +296,33 @@ def process_audio_and_respond(session_id, data):
         temp_path = temp_file.name

         try:
-            # Load audio using WhisperX
+            # Notify client that transcription is starting
             with app.app_context():
                 socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)

-            # Load audio with WhisperX instead of torchaudio
-            audio = whisperx.load_audio(temp_path)
-
-            # Transcribe using WhisperX
-            batch_size = 16  # Adjust based on available memory
-            result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
-
-            # Get the detected language
-            language_code = result["language"]
-            logger.info(f"Detected language: {language_code}")
-
-            # Load alignment model if not already loaded
-            if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
-                # Clear previous models to save memory
-                if models.whisperx_align_model is not None:
-                    del models.whisperx_align_model
-                    del models.whisperx_align_metadata
-                    gc.collect()
-                    torch.cuda.empty_cache() if DEVICE == "cuda" else None
-
-                models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
-                    language_code=language_code, device=DEVICE
-                )
-                models.last_language = language_code
-
-            # Align the transcript
-            result = whisperx.align(
-                result["segments"],
-                models.whisperx_align_model,
-                models.whisperx_align_metadata,
-                audio,
-                DEVICE,
-                return_char_alignments=False
-            )
-
-            # Combine all segments into a single transcript
-            user_text = ' '.join([segment['text'] for segment in result['segments']])
+            # Load audio for ASR processing
+            import librosa
+            speech_array, sampling_rate = librosa.load(temp_path, sr=16000)
+
+            # Convert to required format
+            input_features = models.asr_processor(
+                speech_array,
+                sampling_rate=sampling_rate,
+                return_tensors="pt"
+            ).input_features.to(DEVICE)
+
+            # Generate token ids
+            predicted_ids = models.asr_model.generate(
+                input_features,
+                language="en",
+                task="transcribe"
+            )
+
+            # Decode the predicted ids to text
+            user_text = models.asr_processor.batch_decode(
+                predicted_ids,
+                skip_special_tokens=True
+            )[0]

             # If no text was recognized, don't process further
             if not user_text or len(user_text.strip()) == 0:
@@ -369,8 +358,7 @@ def process_audio_and_respond(session_id, data):
             with app.app_context():
                 socketio.emit('transcription', {
                     'text': user_text,
-                    'speaker': speaker_id,
-                    'segments': result['segments']  # Send detailed segments info
+                    'speaker': speaker_id
                 }, room=session_id)

             # Generate AI response using Llama
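
Taken together, the server-side change replaces the WhisperX transcribe-and-align pipeline with processor -> generate -> batch_decode on openai/whisper-small. A standalone sketch of that same path, useful for verifying the swap outside the Flask app and assuming a local test file sample.wav (hypothetical), is:

# standalone sketch of the new ASR path (same calls as the diff above)
import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "openai/whisper-small"

processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)

# librosa resamples to the 16 kHz rate Whisper expects
speech_array, sampling_rate = librosa.load("sample.wav", sr=16000)

input_features = processor(
    speech_array,
    sampling_rate=sampling_rate,
    return_tensors="pt"
).input_features.to(device)

predicted_ids = model.generate(input_features, language="en", task="transcribe")
text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(text)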

View File (client-side JavaScript)

@@ -105,8 +105,25 @@ function setupSocketConnection() {
     });

     state.socket.on('error', (data) => {
-        addSystemMessage(`Error: ${data.message}`);
         console.error('Server error:', data.message);
+
+        // Make the error more user-friendly
+        let userMessage = data.message;
+
+        // Check for common errors and provide more helpful messages
+        if (data.message.includes('Models still loading')) {
+            userMessage = 'The AI models are still loading. Please wait a moment and try again.';
+        } else if (data.message.includes('No speech detected')) {
+            userMessage = 'No speech was detected. Please speak clearly and try again.';
+        }
+
+        addSystemMessage(`Error: ${userMessage}`);
+
+        // Reset button state if it was processing
+        if (elements.streamButton.classList.contains('processing')) {
+            elements.streamButton.classList.remove('processing');
+            elements.streamButton.innerHTML = '<i class="fas fa-microphone"></i> Start Conversation';
+        }
     });

     // Register message handlers
@@ -115,6 +132,9 @@ function setupSocketConnection() {
     state.socket.on('streaming_status', handleStreamingStatus);
     state.socket.on('processing_status', handleProcessingStatus);

+    // Add model status handlers
+    state.socket.on('model_status', handleModelStatusUpdate);
+
     // Handlers for incremental audio streaming
     state.socket.on('audio_response_start', handleAudioResponseStart);
     state.socket.on('audio_response_chunk', handleAudioResponseChunk);
@@ -189,6 +209,27 @@ function startStreaming() {
         return;
     }

+    // Check if models are loaded via the API
+    fetch('/api/status')
+        .then(response => response.json())
+        .then(data => {
+            if (!data.models.generator || !data.models.asr || !data.models.llm) {
+                addSystemMessage('Still loading AI models. Please wait...');
+                return;
+            }
+            // Continue with recording if models are loaded
+            initializeRecording();
+        })
+        .catch(error => {
+            console.error('Error checking model status:', error);
+            // Try anyway, the server will respond with an error if models aren't ready
+            initializeRecording();
+        });
+}
+
+// Extracted the recording initialization to a separate function
+function initializeRecording() {
     // Request microphone access
     navigator.mediaDevices.getUserMedia({ audio: true, video: false })
         .then(stream => {
@@ -600,6 +641,13 @@ function addMessage(text, type) {
     textElement.textContent = text;
     messageDiv.appendChild(textElement);

+    // Add timestamp to every message
+    const timestamp = new Date().toLocaleTimeString();
+    const timeLabel = document.createElement('div');
+    timeLabel.className = 'message-timestamp';
+    timeLabel.textContent = timestamp;
+    messageDiv.appendChild(timeLabel);
+
     elements.conversation.appendChild(messageDiv);

     // Auto-scroll to the bottom
@@ -668,6 +716,13 @@ function handleTranscription(data) {
         // Add the timestamp elements to the message
         messageDiv.appendChild(toggleButton);
         messageDiv.appendChild(timestampsContainer);
+    } else {
+        // No timestamp data available - add a simple timestamp for the entire message
+        const timestamp = new Date().toLocaleTimeString();
+        const timeLabel = document.createElement('div');
+        timeLabel.className = 'simple-timestamp';
+        timeLabel.textContent = timestamp;
+        messageDiv.appendChild(timeLabel);
     }

     return messageDiv;
@@ -854,6 +909,52 @@ function finalizeStreamingAudio() {
     streamingAudio.audioElement = null;
 }

+// Handle model status updates
+function handleModelStatusUpdate(data) {
+    const { model, status, message } = data;
+
+    if (status === 'loaded') {
+        console.log(`Model ${model} loaded successfully`);
+        addSystemMessage(`${model.toUpperCase()} model loaded successfully`);
+
+        // Update UI to show model is ready
+        const modelStatusElement = document.getElementById(`${model}Status`);
+        if (modelStatusElement) {
+            modelStatusElement.classList.remove('loading');
+            modelStatusElement.classList.add('loaded');
+            modelStatusElement.title = 'Model loaded successfully';
+        }
+
+        // Check if the required models are loaded to enable conversation
+        checkAllModelsLoaded();
+    } else if (status === 'error') {
+        console.error(`Error loading ${model} model: ${message}`);
+        addSystemMessage(`Error loading ${model.toUpperCase()} model: ${message}`);
+
+        // Update UI to show model loading failed
+        const modelStatusElement = document.getElementById(`${model}Status`);
+        if (modelStatusElement) {
+            modelStatusElement.classList.remove('loading');
+            modelStatusElement.classList.add('error');
+            modelStatusElement.title = `Error: ${message}`;
+        }
+    }
+}
+
+// Check if all required models are loaded and enable UI accordingly
+function checkAllModelsLoaded() {
+    // When all models are loaded, enable the stream button if it was disabled
+    const allLoaded =
+        document.getElementById('csmStatus')?.classList.contains('loaded') &&
+        document.getElementById('asrStatus')?.classList.contains('loaded') &&
+        document.getElementById('llmStatus')?.classList.contains('loaded');
+
+    if (allLoaded) {
+        elements.streamButton.disabled = false;
+        addSystemMessage('All models loaded. Ready for conversation!');
+    }
+}
+
 // Add CSS styles for new UI elements
 document.addEventListener('DOMContentLoaded', function() {
     // Add styles for processing state and timestamps