diff --git a/Backend/server.py b/Backend/server.py
index 992e674..4cc4f91 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -8,6 +8,7 @@ import logging
import threading
import queue
import tempfile
+import gc
from typing import Dict, List, Optional, Tuple
import torch
@@ -18,6 +19,9 @@ from flask_socketio import SocketIO, emit
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+# WhisperX provides faster batched transcription plus word-level timestamps via forced alignment
+import whisperx
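+# (assumes the "whisperx" package is installed in the backend environment)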
+
from generator import load_csm_1b, Segment
from dataclasses import dataclass
@@ -52,7 +56,10 @@ class AppModels:
generator = None
tokenizer = None
llm = None
- asr = None
+ whisperx_model = None
+ whisperx_align_model = None
+ whisperx_align_metadata = None
+ diarize_model = None  # reserved for optional speaker diarization; not loaded anywhere in this change
models = AppModels()
@@ -69,22 +76,16 @@ def load_models():
logger.error(f"Error loading CSM 1B model: {str(e)}")
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
- logger.info("Loading ASR pipeline...")
+ logger.info("Loading WhisperX model...")
try:
- # Initialize the pipeline without the language parameter in the constructor
- models.asr = pipeline(
- "automatic-speech-recognition",
- model="openai/whisper-small",
- device=DEVICE
- )
-
- # Configure the model with the appropriate options
- # Note that for whisper, language should be set during inference, not initialization
- logger.info("ASR pipeline loaded successfully")
- socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
+ # Use WhisperX instead of the regular Whisper
+ compute_type = "float16" if DEVICE == "cuda" else "float32"
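+ # "large-v2" is the most accurate Whisper checkpoint but also the heaviest; a smaller one such as "medium" works if GPU memory is tight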
+ models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type)
+ logger.info("WhisperX model loaded successfully")
+ socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'})
except Exception as e:
- logger.error(f"Error loading ASR pipeline: {str(e)}")
- socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
+ logger.error(f"Error loading WhisperX model: {str(e)}")
+ socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)})
logger.info("Loading Llama 3.2 model...")
try:
@@ -149,7 +150,7 @@ def system_status():
"device": DEVICE,
"models": {
"generator": models.generator is not None,
- "asr": models.asr is not None,
+ "whisperx": models.whisperx_model is not None,
"llm": models.llm is not None
}
})
@@ -259,8 +260,8 @@ def process_audio_queue(session_id, q):
del user_queues[session_id]
def process_audio_and_respond(session_id, data):
- """Process audio data and generate a response"""
- if models.generator is None or models.asr is None or models.llm is None:
+ """Process audio data and generate a response using WhisperX"""
+ if models.generator is None or models.whisperx_model is None or models.llm is None:
logger.warning("Models not yet loaded!")
with app.app_context():
socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -292,7 +293,57 @@ def process_audio_and_respond(session_id, data):
temp_path = temp_file.name
try:
- # Load audio file
+ # Load audio using WhisperX
+ with app.app_context():
+ socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
+
+ # Load audio with WhisperX for transcription (torchaudio still loads the file below for CSM input)
+ audio = whisperx.load_audio(temp_path)
+
+ # Transcribe using WhisperX
+ batch_size = 16 # Adjust based on available memory
+ result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
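+ # result contains the raw "segments" (text with coarse timestamps) and the detected "language"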
+
+ # Get the detected language
+ language_code = result["language"]
+ logger.info(f"Detected language: {language_code}")
+
+ # Load (or reload) the language-specific alignment model whenever the detected language changes
+ if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None):
+ # Clear previous models to save memory
+ if models.whisperx_align_model is not None:
+ del models.whisperx_align_model
+ del models.whisperx_align_metadata
+ gc.collect()
+ if DEVICE == "cuda":
+     torch.cuda.empty_cache()
+
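+ # Alignment relies on a separate wav2vec2 model per language, so the first request in a new language pays a one-time load cost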
+ models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
+ language_code=language_code, device=DEVICE
+ )
+ models.last_language = language_code
+
+ # Align the transcript
+ result = whisperx.align(
+ result["segments"],
+ models.whisperx_align_model,
+ models.whisperx_align_metadata,
+ audio,
+ DEVICE,
+ return_char_alignments=False
+ )
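+ # Aligned segments now carry refined start/end times (and per-word timings) used by the frontend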
+
+ # Combine all segments into a single transcript
+ user_text = ' '.join(segment['text'].strip() for segment in result['segments']).strip()
+
+ # If no text was recognized, don't process further
+ if not user_text:
+ with app.app_context():
+ socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
+ return
+
+ logger.info(f"Transcription: {user_text}")
+
+ # Load audio for CSM input
waveform, sample_rate = torchaudio.load(temp_path)
# Normalize to mono if needed
@@ -307,24 +358,6 @@ def process_audio_and_respond(session_id, data):
new_freq=models.generator.sample_rate
)
- # Transcribe audio
- with app.app_context():
- socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
-
- # Use the ASR pipeline to transcribe
- transcription_result = models.asr(
- {"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
- return_timestamps=False,
- generate_kwargs={"language": "en"} # Set language during inference
- )
- user_text = transcription_result['text'].strip()
-
- # If no text was recognized, don't process further
- if not user_text:
- with app.app_context():
- socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
- return
-
# Add the user's message to conversation history
user_segment = conversation.add_segment(
text=user_text,
@@ -336,7 +369,8 @@ def process_audio_and_respond(session_id, data):
with app.app_context():
socketio.emit('transcription', {
'text': user_text,
- 'speaker': speaker_id
+ 'speaker': speaker_id,
+ 'segments': result['segments'] # Send detailed segments info
}, room=session_id)
# Generate AI response using Llama
@@ -365,7 +399,7 @@ def process_audio_and_respond(session_id, data):
with torch.no_grad():
generated_ids = models.llm.generate(
input_ids,
- attention_mask=attention_mask, # Add the attention mask
+ attention_mask=attention_mask,
max_new_tokens=100,
temperature=0.7,
top_p=0.9,
@@ -390,13 +424,13 @@ def process_audio_and_respond(session_id, data):
'chunk_index': 0
}, room=session_id)
- # Define AI speaker ID (use a consistent value for the AI's voice)
- ai_speaker_id = 1 # Use speaker 1 for AI responses
+ # Define AI speaker ID
+ ai_speaker_id = conversation.ai_speaker_id
# Generate audio
audio_tensor = models.generator.generate(
text=response_text,
- speaker=ai_speaker_id, # Use the local variable instead of conversation.ai_speaker_id
+ speaker=ai_speaker_id,
context=conversation.get_context(),
max_audio_length_ms=10_000,
temperature=0.9
@@ -405,7 +439,7 @@ def process_audio_and_respond(session_id, data):
# Add AI response to conversation history
ai_segment = conversation.add_segment(
text=response_text,
- speaker=ai_speaker_id, # Also use the local variable here
+ speaker=ai_speaker_id,
audio=audio_tensor
)
@@ -444,6 +478,8 @@ def process_audio_and_respond(session_id, data):
except Exception as e:
logger.error(f"Error processing audio: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
with app.app_context():
socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
finally:
diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js
index 5c3f247..705d5ab 100644
--- a/Backend/voice-chat.js
+++ b/Backend/voice-chat.js
@@ -627,7 +627,59 @@ function addSystemMessage(text) {
// Handle transcription response from server
function handleTranscription(data) {
const speaker = data.speaker === 0 ? 'user' : 'ai';
- addMessage(data.text, speaker);
+
+ // Create the message div
+ const messageDiv = addMessage(data.text, speaker);
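+ // NOTE: assumes addMessage() returns the message element it creates; have it return the div if it currently doesn't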
+
+ // If we have detailed segments from WhisperX, add timestamps
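+ // Each segment is expected to look roughly like { start: 1.23, end: 2.45, text: "..." }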
+ if (data.segments && data.segments.length > 0) {
+ // Add a timestamps container
+ const timestampsContainer = document.createElement('div');
+ timestampsContainer.className = 'timestamps-container';
+ timestampsContainer.style.display = 'none'; // Hidden by default
+
+ // Add a toggle button
+ const toggleButton = document.createElement('button');
+ toggleButton.className = 'timestamp-toggle';
+ toggleButton.textContent = 'Show Timestamps';
+ toggleButton.onclick = function() {
+ const isHidden = timestampsContainer.style.display === 'none';
+ timestampsContainer.style.display = isHidden ? 'block' : 'none';
+ toggleButton.textContent = isHidden ? 'Hide Timestamps' : 'Show Timestamps';
+ };
+
+ // Add timestamps for each segment
+ data.segments.forEach(segment => {
+ const timestampDiv = document.createElement('div');
+ timestampDiv.className = 'timestamp';
+
+ // Format start and end times
+ const startTime = formatTime(segment.start);
+ const endTime = formatTime(segment.end);
+
+ timestampDiv.innerHTML = `
+ <span class="time">[${startTime} - ${endTime}]</span>
+ <span class="text">${segment.text}</span>
+ `;
+
+ timestampsContainer.appendChild(timestampDiv);
+ });
+
+ // Add the timestamp elements to the message
+ messageDiv.appendChild(toggleButton);
+ messageDiv.appendChild(timestampsContainer);
+ }
+
+ return messageDiv;
+}
+
+// Helper function to format time in seconds to MM:SS.ms format
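+// e.g. formatTime(65.5) -> "01:05.500"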
+function formatTime(seconds) {
+ const mins = Math.floor(seconds / 60);
+ const secs = Math.floor(seconds % 60);
+ const ms = Math.floor((seconds % 1) * 1000);
+
+ return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}.${ms.toString().padStart(3, '0')}`;
}
// Handle context update from server
@@ -804,7 +856,7 @@ function finalizeStreamingAudio() {
// Add CSS styles for new UI elements
document.addEventListener('DOMContentLoaded', function() {
- // Add styles for processing state
+ // Add styles for processing state and timestamps
const style = document.createElement('style');
style.textContent = `
.message.processing {
@@ -833,6 +885,44 @@ document.addEventListener('DOMContentLoaded', function() {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
+
+ /* Timestamp styles */
+ .timestamp-toggle {
+ font-size: 0.75em;
+ padding: 4px 8px;
+ margin-top: 8px;
+ background-color: #f0f0f0;
+ border: 1px solid #ddd;
+ border-radius: 4px;
+ cursor: pointer;
+ }
+
+ .timestamp-toggle:hover {
+ background-color: #e0e0e0;
+ }
+
+ .timestamps-container {
+ margin-top: 8px;
+ padding: 8px;
+ background-color: #f9f9f9;
+ border-radius: 4px;
+ font-size: 0.85em;
+ }
+
+ .timestamp {
+ margin-bottom: 4px;
+ padding: 2px 0;
+ }
+
+ .timestamp .time {
+ color: #666;
+ font-family: monospace;
+ margin-right: 8px;
+ }
+
+ .timestamp .text {
+ color: #333;
+ }
`;
document.head.appendChild(style);
});