diff --git a/Backend/index.html b/Backend/index.html
index 01bd5f7..e69ec9a 100644
--- a/Backend/index.html
+++ b/Backend/index.html
@@ -260,6 +260,47 @@
font-size: 0.8em;
color: #777;
}
+
+ /* Model status indicators */
+ .model-status {
+ display: flex;
+ gap: 8px;
+ }
+
+ .model-indicator {
+ padding: 3px 6px;
+ border-radius: 4px;
+ font-size: 0.7em;
+ font-weight: bold;
+ }
+
+ .model-indicator.loading {
+ background-color: #ffd54f;
+ color: #000;
+ }
+
+ .model-indicator.loaded {
+ background-color: #4CAF50;
+ color: white;
+ }
+
+ .model-indicator.error {
+ background-color: #f44336;
+ color: white;
+ }
+
+ .message-timestamp {
+ font-size: 0.7em;
+ color: #888;
+ margin-top: 4px;
+ text-align: right;
+ }
+
+ .simple-timestamp {
+ font-size: 0.8em;
+ color: #888;
+ margin-top: 5px;
+ }
@@ -276,6 +317,13 @@
Disconnected
+ <div class="model-status">
+   <span id="csmStatus" class="model-indicator loading" title="Model loading...">CSM</span>
+   <span id="asrStatus" class="model-indicator loading" title="Model loading...">ASR</span>
+   <span id="llmStatus" class="model-indicator loading" title="Model loading...">LLM</span>
+ </div>
diff --git a/Backend/requirements.txt b/Backend/requirements.txt
index ba8a04f..1e05eb3 100644
--- a/Backend/requirements.txt
+++ b/Backend/requirements.txt
@@ -1,7 +1,11 @@
+flask==2.2.5
+flask-socketio==5.3.6
+flask-cors==4.0.0
torch==2.4.0
torchaudio==2.4.0
tokenizers==0.21.0
transformers==4.49.0
+librosa==0.10.1
huggingface_hub==0.28.1
moshi==0.2.2
torchtune==0.4.0
diff --git a/Backend/server.py b/Backend/server.py
index 4cc4f91..ab56e77 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -56,12 +56,8 @@ class AppModels:
generator = None
tokenizer = None
llm = None
- whisperx_model = None
- whisperx_align_model = None
- whisperx_align_metadata = None
- diarize_model = None
+ asr_model = None
+ asr_processor = None

models = AppModels()
def load_models():
"""Load all required models"""
@@ -76,16 +72,22 @@ def load_models():
logger.error(f"Error loading CSM 1B model: {str(e)}")
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
- logger.info("Loading WhisperX model...")
+ logger.info("Loading Whisper ASR model...")
try:
- # Use WhisperX instead of the regular Whisper
- compute_type = "float16" if DEVICE == "cuda" else "float32"
- models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
- logger.info("WhisperX model loaded successfully")
- socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'})
+ # Use regular Whisper instead of WhisperX to avoid compatibility issues
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+ # Use a smaller model for faster processing
+ model_id = "openai/whisper-small"
+
+ models.asr_processor = WhisperProcessor.from_pretrained(model_id)
+ models.asr_model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
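+            # whisper-small (~244M params) keeps load time and VRAM modest; larger checkpoints
+            # such as "openai/whisper-medium" or "openai/whisper-large-v3" trade speed for accuracy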
+
+ logger.info("Whisper ASR model loaded successfully")
+ socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
except Exception as e:
- logger.error(f"Error loading WhisperX model: {str(e)}")
- socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)})
+ logger.error(f"Error loading ASR model: {str(e)}")
+ socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
logger.info("Loading Llama 3.2 model...")
try:
@@ -141,7 +143,8 @@ def health_check():
"models_loaded": models.generator is not None and models.llm is not None
})
-# Add a system status endpoint
+# System status endpoint: reports the device in use and which models have finished loading
+
@app.route('/api/status')
def system_status():
return jsonify({
@@ -150,7 +153,7 @@ def system_status():
"device": DEVICE,
"models": {
"generator": models.generator is not None,
- "whisperx": models.whisperx_model is not None,
+ "asr": models.asr_model is not None, # Use the correct model name
"llm": models.llm is not None
}
})
@@ -260,8 +263,8 @@ def process_audio_queue(session_id, q):
del user_queues[session_id]
def process_audio_and_respond(session_id, data):
- """Process audio data and generate a response using WhisperX"""
- if models.generator is None or models.whisperx_model is None or models.llm is None:
+ """Process audio data and generate a response using standard Whisper"""
+ if models.generator is None or models.asr_model is None or models.llm is None:
logger.warning("Models not yet loaded!")
with app.app_context():
socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -293,47 +296,33 @@ def process_audio_and_respond(session_id, data):
temp_path = temp_file.name
try:
- # Load audio using WhisperX
+ # Notify client that transcription is starting
with app.app_context():
socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
- # Load audio with WhisperX instead of torchaudio
- audio = whisperx.load_audio(temp_path)
+ # Load audio for ASR processing
+ import librosa
+ speech_array, sampling_rate = librosa.load(temp_path, sr=16000)
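+            # Whisper expects 16 kHz mono input; librosa resamples and downmixes on load,
+            # so the browser's original sample rate doesn't matter here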
- # Transcribe using WhisperX
- batch_size = 16 # Adjust based on available memory
- result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+            # Convert the raw waveform to the input features Whisper expects
+ input_features = models.asr_processor(
+ speech_array,
+ sampling_rate=sampling_rate,
+ return_tensors="pt"
+ ).input_features.to(DEVICE)
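+            # input_features is a log-mel spectrogram, zero-padded/truncated to Whisper's fixed 30 s window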
- # Get the detected language
- language_code = result["language"]
- logger.info(f"Detected language: {language_code}")
-
- # Load alignment model if not already loaded
- if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
- # Clear previous models to save memory
- if models.whisperx_align_model is not None:
- del models.whisperx_align_model
- del models.whisperx_align_metadata
- gc.collect()
- torch.cuda.empty_cache() if DEVICE == "cuda" else None
-
- models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
- language_code=language_code, device=DEVICE
- )
- models.last_language = language_code
-
- # Align the transcript
- result = whisperx.align(
- result["segments"],
- models.whisperx_align_model,
- models.whisperx_align_metadata,
- audio,
- DEVICE,
- return_char_alignments=False
+            # Generate token ids, pinning language/task so Whisper skips auto language detection
+ predicted_ids = models.asr_model.generate(
+ input_features,
+ language="en",
+ task="transcribe"
)
- # Combine all segments into a single transcript
- user_text = ' '.join([segment['text'] for segment in result['segments']])
+ # Decode the predicted ids to text
+ user_text = models.asr_processor.batch_decode(
+ predicted_ids,
+ skip_special_tokens=True
+ )[0]
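+            # batch_decode returns one transcript per batch item; [0] unwraps our single clip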
# If no text was recognized, don't process further
if not user_text or len(user_text.strip()) == 0:
@@ -369,8 +358,7 @@ def process_audio_and_respond(session_id, data):
with app.app_context():
socketio.emit('transcription', {
'text': user_text,
- 'speaker': speaker_id,
- 'segments': result['segments'] # Send detailed segments info
+ 'speaker': speaker_id
}, room=session_id)
# Generate AI response using Llama
diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js
index 705d5ab..e4f1272 100644
--- a/Backend/voice-chat.js
+++ b/Backend/voice-chat.js
@@ -105,8 +105,25 @@ function setupSocketConnection() {
});
state.socket.on('error', (data) => {
- addSystemMessage(`Error: ${data.message}`);
console.error('Server error:', data.message);
+
+ // Make the error more user-friendly
+ let userMessage = data.message;
+
+ // Check for common errors and provide more helpful messages
+ if (data.message.includes('Models still loading')) {
+ userMessage = 'The AI models are still loading. Please wait a moment and try again.';
+ } else if (data.message.includes('No speech detected')) {
+ userMessage = 'No speech was detected. Please speak clearly and try again.';
+ }
+
+ addSystemMessage(`Error: ${userMessage}`);
+
+ // Reset button state if it was processing
+ if (elements.streamButton.classList.contains('processing')) {
+ elements.streamButton.classList.remove('processing');
+ elements.streamButton.innerHTML = ' Start Conversation';
+ }
});
// Register message handlers
@@ -115,6 +132,9 @@ function setupSocketConnection() {
state.socket.on('streaming_status', handleStreamingStatus);
state.socket.on('processing_status', handleProcessingStatus);
+ // Add model status handlers
+ state.socket.on('model_status', handleModelStatusUpdate);
+
// Handlers for incremental audio streaming
state.socket.on('audio_response_start', handleAudioResponseStart);
state.socket.on('audio_response_chunk', handleAudioResponseChunk);
@@ -189,6 +209,27 @@ function startStreaming() {
return;
}
+ // Check if models are loaded via the API
+ fetch('/api/status')
+ .then(response => response.json())
+ .then(data => {
+ if (!data.models.generator || !data.models.asr || !data.models.llm) {
+ addSystemMessage('Still loading AI models. Please wait...');
+ return;
+ }
+
+ // Continue with recording if models are loaded
+ initializeRecording();
+ })
+ .catch(error => {
+ console.error('Error checking model status:', error);
+ // Try anyway, the server will respond with an error if models aren't ready
+ initializeRecording();
+ });
+}
+
+// Initialize microphone capture; called once model availability has been confirmed
+function initializeRecording() {
// Request microphone access
navigator.mediaDevices.getUserMedia({ audio: true, video: false })
.then(stream => {
@@ -600,6 +641,13 @@ function addMessage(text, type) {
textElement.textContent = text;
messageDiv.appendChild(textElement);
+ // Add timestamp to every message
+ const timestamp = new Date().toLocaleTimeString();
+ const timeLabel = document.createElement('div');
+ timeLabel.className = 'message-timestamp';
+ timeLabel.textContent = timestamp;
+ messageDiv.appendChild(timeLabel);
+
elements.conversation.appendChild(messageDiv);
// Auto-scroll to the bottom
@@ -668,6 +716,13 @@ function handleTranscription(data) {
// Add the timestamp elements to the message
messageDiv.appendChild(toggleButton);
messageDiv.appendChild(timestampsContainer);
+ } else {
+ // No timestamp data available - add a simple timestamp for the entire message
+ const timestamp = new Date().toLocaleTimeString();
+ const timeLabel = document.createElement('div');
+ timeLabel.className = 'simple-timestamp';
+ timeLabel.textContent = timestamp;
+ messageDiv.appendChild(timeLabel);
}
return messageDiv;
@@ -854,6 +909,52 @@ function finalizeStreamingAudio() {
streamingAudio.audioElement = null;
}
+// Handle model status updates
+function handleModelStatusUpdate(data) {
+ const { model, status, message } = data;
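+    // `model` is the short name used in server.py's model_status events (e.g. 'csm', 'asr')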
+
+ if (status === 'loaded') {
+ console.log(`Model ${model} loaded successfully`);
+ addSystemMessage(`${model.toUpperCase()} model loaded successfully`);
+
+ // Update UI to show model is ready
+ const modelStatusElement = document.getElementById(`${model}Status`);
+ if (modelStatusElement) {
+ modelStatusElement.classList.remove('loading');
+ modelStatusElement.classList.add('loaded');
+ modelStatusElement.title = 'Model loaded successfully';
+ }
+
+ // Check if the required models are loaded to enable conversation
+ checkAllModelsLoaded();
+ } else if (status === 'error') {
+ console.error(`Error loading ${model} model: ${message}`);
+ addSystemMessage(`Error loading ${model.toUpperCase()} model: ${message}`);
+
+ // Update UI to show model loading failed
+ const modelStatusElement = document.getElementById(`${model}Status`);
+ if (modelStatusElement) {
+ modelStatusElement.classList.remove('loading');
+ modelStatusElement.classList.add('error');
+ modelStatusElement.title = `Error: ${message}`;
+ }
+ }
+}
+
+// Check if all required models are loaded and enable UI accordingly
+function checkAllModelsLoaded() {
+ // When all models are loaded, enable the stream button if it was disabled
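+    // (assumes the stream button starts disabled in index.html while models load;
+    // if it is already enabled, re-enabling it below is a harmless no-op)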
+ const allLoaded =
+ document.getElementById('csmStatus')?.classList.contains('loaded') &&
+ document.getElementById('asrStatus')?.classList.contains('loaded') &&
+ document.getElementById('llmStatus')?.classList.contains('loaded');
+
+ if (allLoaded) {
+ elements.streamButton.disabled = false;
+ addSystemMessage('All models loaded. Ready for conversation!');
+ }
+}
+
// Add CSS styles for new UI elements
document.addEventListener('DOMContentLoaded', function() {
// Add styles for processing state and timestamps