Commit: BGV, 2025-03-30 03:31:36 -04:00
2 changed files with 183 additions and 70 deletions


@@ -25,6 +25,10 @@ import whisperx
 from generator import load_csm_1b, Segment
 from dataclasses import dataclass
 
+# Add these imports at the top
+import psutil
+import gc
+
 # Configure logging
 logging.basicConfig(level=logging.INFO,
                     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -56,8 +60,11 @@ class AppModels:
     generator = None
     tokenizer = None
     llm = None
-    asr_model = None
-    asr_processor = None
+    whisperx_model = None
+    whisperx_align_model = None
+    whisperx_align_metadata = None
+    diarize_model = None
+    last_language = None
 
 # Initialize the models object
 models = AppModels()
@@ -68,13 +75,13 @@ def load_models():
     socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})
 
-    logger.info("Loading CSM 1B model...")
+    # CSM 1B loading
     try:
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
         models.generator = load_csm_1b(device=DEVICE)
         logger.info("CSM 1B model loaded successfully")
         socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
-        progress = 33
-        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress})
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
         if DEVICE == "cuda":
             torch.cuda.empty_cache()
     except Exception as e:
@@ -83,39 +90,51 @@ def load_models():
logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
logger.info("Loading Whisper ASR model...") # WhisperX loading
try: try:
# Use regular Whisper instead of WhisperX to avoid compatibility issues socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
from transformers import WhisperProcessor, WhisperForConditionalGeneration # Use WhisperX for better transcription with timestamps
import whisperx
# Use a smaller model for faster processing # Use compute_type based on device
model_id = "openai/whisper-small" compute_type = "float16" if DEVICE == "cuda" else "float32"
models.asr_processor = WhisperProcessor.from_pretrained(model_id) # Load the WhisperX model (smaller model for faster processing)
models.asr_model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE) models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)
logger.info("Whisper ASR model loaded successfully") logger.info("WhisperX model loaded successfully")
socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
progress = 66 socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress})
if DEVICE == "cuda": if DEVICE == "cuda":
torch.cuda.empty_cache() torch.cuda.empty_cache()
except Exception as e: except Exception as e:
logger.error(f"Error loading ASR model: {str(e)}") import traceback
error_details = traceback.format_exc()
logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
logger.info("Loading Llama 3.2 model...") # Llama loading
try: try:
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
models.llm = AutoModelForCausalLM.from_pretrained( models.llm = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B", "meta-llama/Llama-3.2-1B",
device_map=DEVICE, device_map=DEVICE,
torch_dtype=torch.bfloat16 torch_dtype=torch.bfloat16
) )
models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
# Configure all special tokens
models.tokenizer.pad_token = models.tokenizer.eos_token
models.tokenizer.padding_side = "left" # For causal language modeling
# Inform the model about the pad token
if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
models.llm.config.pad_token_id = models.tokenizer.pad_token_id
logger.info("Llama 3.2 model loaded successfully") logger.info("Llama 3.2 model loaded successfully")
socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
progress = 100 socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
socketio.emit('model_status', {'model': 'overall', 'status': 'loaded', 'progress': progress}) socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
except Exception as e: except Exception as e:
logger.error(f"Error loading Llama 3.2 model: {str(e)}") logger.error(f"Error loading Llama 3.2 model: {str(e)}")
socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)}) socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})
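
Note on the pad-token lines added above: Llama tokenizers ship without a dedicated pad token, so padded or batched generation fails unless one is assigned. A minimal standalone sketch of the same setup, assuming access to the gated meta-llama/Llama-3.2-1B checkpoint (the prompts and max_new_tokens below are illustrative, not part of this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    llm = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16)

    tok.pad_token = tok.eos_token      # Llama has no pad token by default
    tok.padding_side = "left"          # keep each prompt adjacent to its newly generated tokens
    llm.config.pad_token_id = tok.pad_token_id

    batch = tok(["Hello there.", "A longer prompt that forces padding."],
                return_tensors="pt", padding=True)
    out = llm.generate(**batch, max_new_tokens=20, pad_token_id=tok.pad_token_id)
    # Drop the echoed prompt tokens before decoding
    print(tok.batch_decode(out[:, batch["input_ids"].shape[1]:], skip_special_tokens=True))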
@@ -170,11 +189,44 @@ def system_status():
"device": DEVICE, "device": DEVICE,
"models": { "models": {
"generator": models.generator is not None, "generator": models.generator is not None,
"asr": models.asr_model is not None, # Use the correct model name "asr": models.whisperx_model is not None, # Use the correct model name
"llm": models.llm is not None "llm": models.llm is not None
} }
}) })
# Add a new endpoint to check system resources
@app.route('/api/system_resources')
def system_resources():
# Get CPU usage
cpu_percent = psutil.cpu_percent(interval=0.1)
# Get memory usage
memory = psutil.virtual_memory()
memory_used_gb = memory.used / (1024 ** 3)
memory_total_gb = memory.total / (1024 ** 3)
memory_percent = memory.percent
# Get GPU memory if available
gpu_memory = {}
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
gpu_memory[f"gpu_{i}"] = {
"allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
"reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
"max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
}
return jsonify({
"cpu_percent": cpu_percent,
"memory": {
"used_gb": memory_used_gb,
"total_gb": memory_total_gb,
"percent": memory_percent
},
"gpu_memory": gpu_memory,
"active_sessions": len(active_conversations)
})
# Socket event handlers # Socket event handlers
@socketio.on('connect') @socketio.on('connect')
def handle_connect(auth=None): def handle_connect(auth=None):
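
A quick way to exercise the new /api/system_resources endpoint from a separate process; the host and port here are assumptions based on Flask defaults and are not taken from this commit:

    import time
    import requests  # assumes the requests package is installed

    while True:
        data = requests.get("http://localhost:5000/api/system_resources", timeout=5).json()
        gpus = ", ".join(f"{name}: {mem['allocated']:.1f} GiB allocated"
                         for name, mem in data["gpu_memory"].items()) or "no GPU"
        print(f"CPU {data['cpu_percent']:.0f}% | "
              f"RAM {data['memory']['used_gb']:.1f}/{data['memory']['total_gb']:.1f} GiB | "
              f"{gpus} | sessions: {data['active_sessions']}")
        time.sleep(5)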
@@ -280,8 +332,8 @@ def process_audio_queue(session_id, q):
         del user_queues[session_id]
 
 def process_audio_and_respond(session_id, data):
-    """Process audio data and generate a response using standard Whisper"""
-    if models.generator is None or models.asr_model is None or models.llm is None:
+    """Process audio data and generate a response using WhisperX"""
+    if models.generator is None or models.whisperx_model is None or models.llm is None:
         logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -317,29 +369,69 @@ def process_audio_and_respond(session_id, data):
         with app.app_context():
             socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
 
-        # Load audio for ASR processing
-        import librosa
-        speech_array, sampling_rate = librosa.load(temp_path, sr=16000)
-        # Convert to required format
-        input_features = models.asr_processor(
-            speech_array,
-            sampling_rate=sampling_rate,
-            return_tensors="pt"
-        ).input_features.to(DEVICE)
-        # Generate token ids
-        predicted_ids = models.asr_model.generate(
-            input_features,
-            language="en",
-            task="transcribe"
-        )
-        # Decode the predicted ids to text
-        user_text = models.asr_processor.batch_decode(
-            predicted_ids,
-            skip_special_tokens=True
-        )[0]
+        # Load audio using WhisperX
+        import whisperx
+        audio = whisperx.load_audio(temp_path)
+
+        # Check audio length and add a warning for short clips
+        audio_length = len(audio) / 16000  # assuming 16kHz sample rate
+        if audio_length < 1.0:
+            logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")
+
+        # Transcribe using WhisperX
+        batch_size = 16  # adjust based on your GPU memory
+        logger.info("Running WhisperX transcription...")
+
+        # Handle the warning about audio being shorter than 30s by suppressing it
+        import warnings
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="audio is shorter than 30s")
+            result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+
+        # Get the detected language
+        language_code = result["language"]
+        logger.info(f"Detected language: {language_code}")
+
+        # Check if alignment model needs to be loaded or updated
+        if models.whisperx_align_model is None or language_code != models.last_language:
+            # Clean up old models if they exist
+            if models.whisperx_align_model is not None:
+                del models.whisperx_align_model
+                del models.whisperx_align_metadata
+                if DEVICE == "cuda":
+                    gc.collect()
+                    torch.cuda.empty_cache()
+
+            # Load new alignment model for the detected language
+            logger.info(f"Loading alignment model for language: {language_code}")
+            models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
+                language_code=language_code, device=DEVICE
+            )
+            models.last_language = language_code
+
+        # Align the transcript to get word-level timestamps
+        if result["segments"] and len(result["segments"]) > 0:
+            logger.info("Aligning transcript...")
+            result = whisperx.align(
+                result["segments"],
+                models.whisperx_align_model,
+                models.whisperx_align_metadata,
+                audio,
+                DEVICE,
+                return_char_alignments=False
+            )
+
+        # Process the segments for better output
+        for segment in result["segments"]:
+            # Round timestamps for better display
+            segment["start"] = round(segment["start"], 2)
+            segment["end"] = round(segment["end"], 2)
+            # Add a confidence score if not present
+            if "confidence" not in segment:
+                segment["confidence"] = 1.0  # Default confidence
+
+        # Extract the full text from all segments
+        user_text = ' '.join([segment['text'] for segment in result['segments']])
 
         # If no text was recognized, don't process further
         if not user_text or len(user_text.strip()) == 0:
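
For reference, the WhisperX flow this handler now uses condenses to roughly the following outside of the Flask/SocketIO plumbing. It is a minimal sketch assuming whisperx is installed, a CUDA device, and a local file named audio.wav (the file name and batch size are illustrative):

    import whisperx

    device = "cuda"
    model = whisperx.load_model("small", device, compute_type="float16")

    audio = whisperx.load_audio("audio.wav")            # decoded and resampled to 16 kHz mono
    result = model.transcribe(audio, batch_size=16)     # segments plus detected language

    align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], align_model, metadata, audio, device,
                            return_char_alignments=False)  # adds word-level timestamps

    print(" ".join(seg["text"] for seg in result["segments"]))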
@@ -371,11 +463,12 @@ def process_audio_and_respond(session_id, data):
             audio=waveform.squeeze()
         )
 
-        # Send transcription to client
+        # Send transcription to client with detailed segments
         with app.app_context():
             socketio.emit('transcription', {
                 'text': user_text,
-                'speaker': speaker_id
+                'speaker': speaker_id,
+                'segments': result['segments'] # Include the detailed segments with timestamps
             }, room=session_id)
 
         # Generate AI response using Llama
@@ -392,6 +485,11 @@ def process_audio_and_respond(session_id, data):
         prompt = f"{conversation_history}Assistant: "
 
         # Generate response with Llama
+        try:
+            # Ensure pad token is set
+            if models.tokenizer.pad_token is None:
+                models.tokenizer.pad_token = models.tokenizer.eos_token
             input_tokens = models.tokenizer(
                 prompt,
                 return_tensors="pt",
@@ -417,6 +515,11 @@ def process_audio_and_respond(session_id, data):
                 generated_ids[0][input_ids.shape[1]:],
                 skip_special_tokens=True
             ).strip()
+        except Exception as e:
+            logger.error(f"Error generating response: {str(e)}")
+            import traceback
+            logger.error(traceback.format_exc())
+            response_text = "I'm sorry, I encountered an error while processing your request."
 
         # Synthesize speech
         with app.app_context():
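
The decode call wrapped by the new try/except keeps only the tokens produced after the prompt; in isolation the pattern is simply this (names reused from the handler above, purely for illustration):

    # generated_ids holds the prompt followed by the newly generated tokens;
    # slicing at input_ids.shape[1] drops the echoed prompt before decoding.
    new_tokens = generated_ids[0][input_ids.shape[1]:]
    response_text = models.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()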


@@ -43,7 +43,9 @@ const state = {
     volumeUpdateInterval: null,
     visualizerAnimationFrame: null,
     currentSpeaker: 0,
-    aiSpeakerId: 1 // Define the AI's speaker ID to match server.py
+    aiSpeakerId: 1, // Define the AI's speaker ID to match server.py
+    transcriptionRetries: 0,
+    maxTranscriptionRetries: 3
 };
 
 // Visualizer variables
@@ -429,7 +431,15 @@ function handleSpeechState(isSilent) {
     if (!hasAudioContent) {
         console.warn('Audio buffer appears to be empty or very quiet');
-        addSystemMessage('No speech detected. Please try again and speak clearly.');
+
+        if (state.transcriptionRetries < state.maxTranscriptionRetries) {
+            state.transcriptionRetries++;
+            const retryMessage = `No speech detected (attempt ${state.transcriptionRetries}/${state.maxTranscriptionRetries}). Please speak louder and try again.`;
+            addSystemMessage(retryMessage);
+        } else {
+            state.transcriptionRetries = 0;
+            addSystemMessage('Multiple attempts failed to detect speech. Please check your microphone and try again.');
+        }
         return;
     }