Demo Update 20

@@ -260,6 +260,47 @@
     font-size: 0.8em;
     color: #777;
 }
+
+/* Model status indicators */
+.model-status {
+    display: flex;
+    gap: 8px;
+}
+
+.model-indicator {
+    padding: 3px 6px;
+    border-radius: 4px;
+    font-size: 0.7em;
+    font-weight: bold;
+}
+
+.model-indicator.loading {
+    background-color: #ffd54f;
+    color: #000;
+}
+
+.model-indicator.loaded {
+    background-color: #4CAF50;
+    color: white;
+}
+
+.model-indicator.error {
+    background-color: #f44336;
+    color: white;
+}
+
+.message-timestamp {
+    font-size: 0.7em;
+    color: #888;
+    margin-top: 4px;
+    text-align: right;
+}
+
+.simple-timestamp {
+    font-size: 0.8em;
+    color: #888;
+    margin-top: 5px;
+}
 </style>
 </head>
 <body>

@@ -276,6 +317,13 @@
             <div id="statusDot" class="status-dot"></div>
             <span id="statusText">Disconnected</span>
         </div>
+
+        <!-- Add this model status panel -->
+        <div class="model-status">
+            <div id="csmStatus" class="model-indicator loading" title="Loading CSM model...">CSM</div>
+            <div id="asrStatus" class="model-indicator loading" title="Loading ASR model...">ASR</div>
+            <div id="llmStatus" class="model-indicator loading" title="Loading LLM model...">LLM</div>
+        </div>
     </div>
     <div id="conversation" class="conversation"></div>
 </div>

@@ -1,7 +1,11 @@
+flask==2.2.5
+flask-socketio==5.3.6
+flask-cors==4.0.0
 torch==2.4.0
 torchaudio==2.4.0
 tokenizers==0.21.0
 transformers==4.49.0
+librosa==0.10.1
 huggingface_hub==0.28.1
 moshi==0.2.2
 torchtune==0.4.0
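
Note: the four added pins (flask, flask-socketio, flask-cors, librosa) are what the new SocketIO status events and the Whisper audio path below rely on. A quick sanity-check sketch (not part of the commit) to confirm the pinned set imports together:

# Sanity check for the pinned dependency set; run in the target environment.
import flask
import flask_cors
import flask_socketio
import librosa
import torch
import transformers

# Print the resolved versions to compare against requirements.txt.
print(flask.__version__, torch.__version__, transformers.__version__, librosa.__version__)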

@@ -56,12 +56,8 @@ class AppModels:
     generator = None
     tokenizer = None
     llm = None
-    whisperx_model = None
-    whisperx_align_model = None
-    whisperx_align_metadata = None
-    diarize_model = None
+    asr_model = None
+    asr_processor = None
 
-models = AppModels()
-
 def load_models():
     """Load all required models"""

@@ -76,16 +72,22 @@ def load_models():
         logger.error(f"Error loading CSM 1B model: {str(e)}")
         socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
 
-    logger.info("Loading WhisperX model...")
+    logger.info("Loading Whisper ASR model...")
     try:
-        # Use WhisperX instead of the regular Whisper
-        compute_type = "float16" if DEVICE == "cuda" else "float32"
-        models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
-        logger.info("WhisperX model loaded successfully")
-        socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'})
+        # Use regular Whisper instead of WhisperX to avoid compatibility issues
+        from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+        # Use a smaller model for faster processing
+        model_id = "openai/whisper-small"
+
+        models.asr_processor = WhisperProcessor.from_pretrained(model_id)
+        models.asr_model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
+
+        logger.info("Whisper ASR model loaded successfully")
+        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
     except Exception as e:
-        logger.error(f"Error loading WhisperX model: {str(e)}")
-        socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)})
+        logger.error(f"Error loading ASR model: {str(e)}")
+        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
 
     logger.info("Loading Llama 3.2 model...")
     try:
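
Note: the replacement ASR loading above follows the stock transformers pattern. For reference, a minimal standalone sketch of the same pattern; it assumes only that torch and transformers are installed, and names outside the diff are illustrative:

# Minimal sketch of the Whisper loading pattern used in this hunk.
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "openai/whisper-small"

# Processor handles featurization and decoding; the model does generation.
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
model.eval()  # inference only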

@@ -141,7 +143,8 @@ def health_check():
         "models_loaded": models.generator is not None and models.llm is not None
     })
 
-# Add a system status endpoint
+# Fix the system_status function:
+
 @app.route('/api/status')
 def system_status():
     return jsonify({

@@ -150,7 +153,7 @@ def system_status():
         "device": DEVICE,
         "models": {
             "generator": models.generator is not None,
-            "whisperx": models.whisperx_model is not None,
+            "asr": models.asr_model is not None,  # Use the correct model name
             "llm": models.llm is not None
         }
     })
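
Note: clients can gate recording on this endpoint, as the frontend change further down does. A hedged Python equivalent of that check, assuming the dev server listens on localhost:5000 and that requests is available:

# Sketch of polling /api/status before starting a conversation.
import requests

status = requests.get("http://localhost:5000/api/status").json()
# The endpoint reports per-model readiness under "models".
if all(status["models"].get(name) for name in ("generator", "asr", "llm")):
    print("All models loaded on", status["device"])
else:
    print("Models still loading:", status["models"])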

@@ -260,8 +263,8 @@ def process_audio_queue(session_id, q):
         del user_queues[session_id]
 
 def process_audio_and_respond(session_id, data):
-    """Process audio data and generate a response using WhisperX"""
-    if models.generator is None or models.whisperx_model is None or models.llm is None:
+    """Process audio data and generate a response using standard Whisper"""
+    if models.generator is None or models.asr_model is None or models.llm is None:
         logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)

@@ -293,47 +296,33 @@ def process_audio_and_respond(session_id, data):
        temp_path = temp_file.name
 
    try:
-        # Load audio using WhisperX
+        # Notify client that transcription is starting
        with app.app_context():
            socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
 
-        # Load audio with WhisperX instead of torchaudio
-        audio = whisperx.load_audio(temp_path)
+        # Load audio for ASR processing
+        import librosa
+        speech_array, sampling_rate = librosa.load(temp_path, sr=16000)
 
-        # Transcribe using WhisperX
-        batch_size = 16  # Adjust based on available memory
-        result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+        # Convert to required format
+        input_features = models.asr_processor(
+            speech_array,
+            sampling_rate=sampling_rate,
+            return_tensors="pt"
+        ).input_features.to(DEVICE)
 
-        # Get the detected language
-        language_code = result["language"]
-        logger.info(f"Detected language: {language_code}")
-
-        # Load alignment model if not already loaded
-        if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
-            # Clear previous models to save memory
-            if models.whisperx_align_model is not None:
-                del models.whisperx_align_model
-                del models.whisperx_align_metadata
-                gc.collect()
-                torch.cuda.empty_cache() if DEVICE == "cuda" else None
-
-            models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
-                language_code=language_code, device=DEVICE
-            )
-            models.last_language = language_code
-
-        # Align the transcript
-        result = whisperx.align(
-            result["segments"],
-            models.whisperx_align_model,
-            models.whisperx_align_metadata,
-            audio,
-            DEVICE,
-            return_char_alignments=False
-        )
+        # Generate token ids
+        predicted_ids = models.asr_model.generate(
+            input_features,
+            language="en",
+            task="transcribe"
+        )
 
-        # Combine all segments into a single transcript
-        user_text = ' '.join([segment['text'] for segment in result['segments']])
+        # Decode the predicted ids to text
+        user_text = models.asr_processor.batch_decode(
+            predicted_ids,
+            skip_special_tokens=True
+        )[0]
 
        # If no text was recognized, don't process further
        if not user_text or len(user_text.strip()) == 0:
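
Note: the new transcription path above is the standard transformers Whisper recipe (load at 16 kHz, featurize, generate, decode). Pulled out of the request handler, it reduces to the self-contained sketch below; the model id matches the diff, the file name is hypothetical:

# Standalone sketch of the new ASR path: librosa load -> processor -> generate -> decode.
import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(DEVICE)

def transcribe(path: str) -> str:
    # Whisper expects 16 kHz mono audio.
    speech_array, sampling_rate = librosa.load(path, sr=16000)
    input_features = processor(
        speech_array, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features.to(DEVICE)
    # Force English transcription, as the handler does.
    predicted_ids = model.generate(input_features, language="en", task="transcribe")
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

print(transcribe("sample.wav"))  # hypothetical input file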

@@ -369,8 +358,7 @@ def process_audio_and_respond(session_id, data):
        with app.app_context():
            socketio.emit('transcription', {
                'text': user_text,
-                'speaker': speaker_id,
-                'segments': result['segments']  # Send detailed segments info
+                'speaker': speaker_id
            }, room=session_id)
 
        # Generate AI response using Llama

@@ -105,8 +105,25 @@ function setupSocketConnection() {
    });
 
    state.socket.on('error', (data) => {
-        addSystemMessage(`Error: ${data.message}`);
        console.error('Server error:', data.message);
+
+        // Make the error more user-friendly
+        let userMessage = data.message;
+
+        // Check for common errors and provide more helpful messages
+        if (data.message.includes('Models still loading')) {
+            userMessage = 'The AI models are still loading. Please wait a moment and try again.';
+        } else if (data.message.includes('No speech detected')) {
+            userMessage = 'No speech was detected. Please speak clearly and try again.';
+        }
+
+        addSystemMessage(`Error: ${userMessage}`);
+
+        // Reset button state if it was processing
+        if (elements.streamButton.classList.contains('processing')) {
+            elements.streamButton.classList.remove('processing');
+            elements.streamButton.innerHTML = '<i class="fas fa-microphone"></i> Start Conversation';
+        }
    });
 
    // Register message handlers

@@ -115,6 +132,9 @@ function setupSocketConnection() {
    state.socket.on('streaming_status', handleStreamingStatus);
    state.socket.on('processing_status', handleProcessingStatus);
 
+    // Add model status handlers
+    state.socket.on('model_status', handleModelStatusUpdate);
+
    // Handlers for incremental audio streaming
    state.socket.on('audio_response_start', handleAudioResponseStart);
    state.socket.on('audio_response_chunk', handleAudioResponseChunk);

@@ -189,6 +209,27 @@ function startStreaming() {
        return;
    }
 
+    // Check if models are loaded via the API
+    fetch('/api/status')
+        .then(response => response.json())
+        .then(data => {
+            if (!data.models.generator || !data.models.asr || !data.models.llm) {
+                addSystemMessage('Still loading AI models. Please wait...');
+                return;
+            }
+
+            // Continue with recording if models are loaded
+            initializeRecording();
+        })
+        .catch(error => {
+            console.error('Error checking model status:', error);
+            // Try anyway, the server will respond with an error if models aren't ready
+            initializeRecording();
+        });
+}
+
+// Extracted the recording initialization to a separate function
+function initializeRecording() {
    // Request microphone access
    navigator.mediaDevices.getUserMedia({ audio: true, video: false })
        .then(stream => {

@@ -600,6 +641,13 @@ function addMessage(text, type) {
    textElement.textContent = text;
    messageDiv.appendChild(textElement);
+
+    // Add timestamp to every message
+    const timestamp = new Date().toLocaleTimeString();
+    const timeLabel = document.createElement('div');
+    timeLabel.className = 'message-timestamp';
+    timeLabel.textContent = timestamp;
+    messageDiv.appendChild(timeLabel);
 
    elements.conversation.appendChild(messageDiv);
 
    // Auto-scroll to the bottom

@@ -668,6 +716,13 @@ function handleTranscription(data) {
        // Add the timestamp elements to the message
        messageDiv.appendChild(toggleButton);
        messageDiv.appendChild(timestampsContainer);
+    } else {
+        // No timestamp data available - add a simple timestamp for the entire message
+        const timestamp = new Date().toLocaleTimeString();
+        const timeLabel = document.createElement('div');
+        timeLabel.className = 'simple-timestamp';
+        timeLabel.textContent = timestamp;
+        messageDiv.appendChild(timeLabel);
    }
 
    return messageDiv;

@@ -854,6 +909,52 @@ function finalizeStreamingAudio() {
        streamingAudio.audioElement = null;
    }
}
+
+// Handle model status updates
+function handleModelStatusUpdate(data) {
+    const { model, status, message } = data;
+
+    if (status === 'loaded') {
+        console.log(`Model ${model} loaded successfully`);
+        addSystemMessage(`${model.toUpperCase()} model loaded successfully`);
+
+        // Update UI to show model is ready
+        const modelStatusElement = document.getElementById(`${model}Status`);
+        if (modelStatusElement) {
+            modelStatusElement.classList.remove('loading');
+            modelStatusElement.classList.add('loaded');
+            modelStatusElement.title = 'Model loaded successfully';
+        }
+
+        // Check if the required models are loaded to enable conversation
+        checkAllModelsLoaded();
+    } else if (status === 'error') {
+        console.error(`Error loading ${model} model: ${message}`);
+        addSystemMessage(`Error loading ${model.toUpperCase()} model: ${message}`);
+
+        // Update UI to show model loading failed
+        const modelStatusElement = document.getElementById(`${model}Status`);
+        if (modelStatusElement) {
+            modelStatusElement.classList.remove('loading');
+            modelStatusElement.classList.add('error');
+            modelStatusElement.title = `Error: ${message}`;
+        }
+    }
+}
+
+// Check if all required models are loaded and enable UI accordingly
+function checkAllModelsLoaded() {
+    // When all models are loaded, enable the stream button if it was disabled
+    const allLoaded =
+        document.getElementById('csmStatus')?.classList.contains('loaded') &&
+        document.getElementById('asrStatus')?.classList.contains('loaded') &&
+        document.getElementById('llmStatus')?.classList.contains('loaded');
+
+    if (allLoaded) {
+        elements.streamButton.disabled = false;
+        addSystemMessage('All models loaded. Ready for conversation!');
+    }
+}
 
// Add CSS styles for new UI elements
document.addEventListener('DOMContentLoaded', function() {
    // Add styles for processing state and timestamps
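
Note: the handlers above consume 'model_status' events shaped as {model, status, message?}. A sketch of the emitting side of that contract; the Flask-SocketIO wiring here is assumed, not taken from the diff:

# Sketch of the server side of the model_status contract (app wiring assumed).
from flask import Flask
from flask_socketio import SocketIO

app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*")

def report_model_status(model: str, ok: bool, message: str = ""):
    # model is one of "csm", "asr", "llm"; the frontend maps it to `${model}Status`.
    payload = {"model": model, "status": "loaded" if ok else "error"}
    if message:
        payload["message"] = message
    socketio.emit("model_status", payload)

# Example: report_model_status("asr", True) after the Whisper model finishes loading.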