Merge branch 'main' of https://github.com/GamerBoss101/HooHacks-12
@@ -25,6 +25,10 @@ import whisperx
 from generator import load_csm_1b, Segment
 from dataclasses import dataclass
 
+# Add these imports at the top
+import psutil
+import gc
+
 # Configure logging
 logging.basicConfig(level=logging.INFO,
                     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -56,8 +60,11 @@ class AppModels:
     generator = None
     tokenizer = None
     llm = None
-    asr_model = None
-    asr_processor = None
+    whisperx_model = None
+    whisperx_align_model = None
+    whisperx_align_metadata = None
+    diarize_model = None
+    last_language = None
 
 # Initialize the models object
 models = AppModels()
@@ -68,13 +75,13 @@ def load_models():
 
     socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})
 
-    logger.info("Loading CSM 1B model...")
+    # CSM 1B loading
     try:
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
         models.generator = load_csm_1b(device=DEVICE)
         logger.info("CSM 1B model loaded successfully")
         socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
-        progress = 33
-        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress})
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
         if DEVICE == "cuda":
             torch.cuda.empty_cache()
     except Exception as e:
@@ -83,39 +90,51 @@ def load_models():
         logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
         socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
 
-    logger.info("Loading Whisper ASR model...")
+    # WhisperX loading
     try:
-        # Use regular Whisper instead of WhisperX to avoid compatibility issues
-        from transformers import WhisperProcessor, WhisperForConditionalGeneration
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
+        # Use WhisperX for better transcription with timestamps
+        import whisperx
 
-        # Use a smaller model for faster processing
-        model_id = "openai/whisper-small"
+        # Use compute_type based on device
+        compute_type = "float16" if DEVICE == "cuda" else "float32"
 
-        models.asr_processor = WhisperProcessor.from_pretrained(model_id)
-        models.asr_model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
+        # Load the WhisperX model (smaller model for faster processing)
+        models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)
 
-        logger.info("Whisper ASR model loaded successfully")
+        logger.info("WhisperX model loaded successfully")
         socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
-        progress = 66
-        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress})
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
         if DEVICE == "cuda":
             torch.cuda.empty_cache()
     except Exception as e:
-        logger.error(f"Error loading ASR model: {str(e)}")
+        import traceback
+        error_details = traceback.format_exc()
+        logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
         socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
 
-    logger.info("Loading Llama 3.2 model...")
+    # Llama loading
     try:
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
         models.llm = AutoModelForCausalLM.from_pretrained(
             "meta-llama/Llama-3.2-1B",
             device_map=DEVICE,
             torch_dtype=torch.bfloat16
         )
         models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
 
+        # Configure all special tokens
+        models.tokenizer.pad_token = models.tokenizer.eos_token
+        models.tokenizer.padding_side = "left"  # For causal language modeling
+
+        # Inform the model about the pad token
+        if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
+            models.llm.config.pad_token_id = models.tokenizer.pad_token_id
+
         logger.info("Llama 3.2 model loaded successfully")
         socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
-        progress = 100
-        socketio.emit('model_status', {'model': 'overall', 'status': 'loaded', 'progress': progress})
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
     except Exception as e:
         logger.error(f"Error loading Llama 3.2 model: {str(e)}")
         socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})
@@ -170,11 +189,44 @@ def system_status():
         "device": DEVICE,
         "models": {
             "generator": models.generator is not None,
-            "asr": models.asr_model is not None,  # Use the correct model name
+            "asr": models.whisperx_model is not None,  # Use the correct model name
             "llm": models.llm is not None
         }
     })
 
+# Add a new endpoint to check system resources
+@app.route('/api/system_resources')
+def system_resources():
+    # Get CPU usage
+    cpu_percent = psutil.cpu_percent(interval=0.1)
+
+    # Get memory usage
+    memory = psutil.virtual_memory()
+    memory_used_gb = memory.used / (1024 ** 3)
+    memory_total_gb = memory.total / (1024 ** 3)
+    memory_percent = memory.percent
+
+    # Get GPU memory if available
+    gpu_memory = {}
+    if torch.cuda.is_available():
+        for i in range(torch.cuda.device_count()):
+            gpu_memory[f"gpu_{i}"] = {
+                "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
+                "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
+                "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
+            }
+
+    return jsonify({
+        "cpu_percent": cpu_percent,
+        "memory": {
+            "used_gb": memory_used_gb,
+            "total_gb": memory_total_gb,
+            "percent": memory_percent
+        },
+        "gpu_memory": gpu_memory,
+        "active_sessions": len(active_conversations)
+    })
+
 # Socket event handlers
 @socketio.on('connect')
 def handle_connect(auth=None):
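For context only (not part of this commit): a minimal sketch of polling the new /api/system_resources endpoint from a separate process, assuming the Flask-SocketIO server above is reachable at http://localhost:5000 (host and port are assumptions) and the requests package is installed. The JSON keys match those returned by system_resources() above.

import requests

BASE_URL = "http://localhost:5000"  # hypothetical; adjust to the actual server address

stats = requests.get(f"{BASE_URL}/api/system_resources", timeout=5).json()
print(f"CPU: {stats['cpu_percent']}%")
print(f"RAM: {stats['memory']['used_gb']:.1f} / {stats['memory']['total_gb']:.1f} GiB ({stats['memory']['percent']}%)")
for name, mem in stats.get("gpu_memory", {}).items():
    print(f"{name}: {mem['allocated']:.2f} GiB allocated, {mem['reserved']:.2f} GiB reserved")
print(f"Active sessions: {stats['active_sessions']}")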
@@ -280,8 +332,8 @@ def process_audio_queue(session_id, q):
         del user_queues[session_id]
 
 def process_audio_and_respond(session_id, data):
-    """Process audio data and generate a response using standard Whisper"""
-    if models.generator is None or models.asr_model is None or models.llm is None:
+    """Process audio data and generate a response using WhisperX"""
+    if models.generator is None or models.whisperx_model is None or models.llm is None:
         logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
@@ -317,29 +369,69 @@ def process_audio_and_respond(session_id, data):
         with app.app_context():
             socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
 
-        # Load audio for ASR processing
-        import librosa
-        speech_array, sampling_rate = librosa.load(temp_path, sr=16000)
+        # Load audio using WhisperX
+        import whisperx
+        audio = whisperx.load_audio(temp_path)
 
-        # Convert to required format
-        input_features = models.asr_processor(
-            speech_array,
-            sampling_rate=sampling_rate,
-            return_tensors="pt"
-        ).input_features.to(DEVICE)
+        # Check audio length and add a warning for short clips
+        audio_length = len(audio) / 16000  # assuming 16kHz sample rate
+        if audio_length < 1.0:
+            logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")
 
-        # Generate token ids
-        predicted_ids = models.asr_model.generate(
-            input_features,
-            language="en",
-            task="transcribe"
-        )
+        # Transcribe using WhisperX
+        batch_size = 16  # adjust based on your GPU memory
+        logger.info("Running WhisperX transcription...")
 
-        # Decode the predicted ids to text
-        user_text = models.asr_processor.batch_decode(
-            predicted_ids,
-            skip_special_tokens=True
-        )[0]
+        # Handle the warning about audio being shorter than 30s by suppressing it
+        import warnings
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="audio is shorter than 30s")
+            result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
+
+        # Get the detected language
+        language_code = result["language"]
+        logger.info(f"Detected language: {language_code}")
+
+        # Check if alignment model needs to be loaded or updated
+        if models.whisperx_align_model is None or language_code != models.last_language:
+            # Clean up old models if they exist
+            if models.whisperx_align_model is not None:
+                del models.whisperx_align_model
+                del models.whisperx_align_metadata
+                if DEVICE == "cuda":
+                    gc.collect()
+                    torch.cuda.empty_cache()
+
+            # Load new alignment model for the detected language
+            logger.info(f"Loading alignment model for language: {language_code}")
+            models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
+                language_code=language_code, device=DEVICE
+            )
+            models.last_language = language_code
+
+        # Align the transcript to get word-level timestamps
+        if result["segments"] and len(result["segments"]) > 0:
+            logger.info("Aligning transcript...")
+            result = whisperx.align(
+                result["segments"],
+                models.whisperx_align_model,
+                models.whisperx_align_metadata,
+                audio,
+                DEVICE,
+                return_char_alignments=False
+            )
+
+            # Process the segments for better output
+            for segment in result["segments"]:
+                # Round timestamps for better display
+                segment["start"] = round(segment["start"], 2)
+                segment["end"] = round(segment["end"], 2)
+                # Add a confidence score if not present
+                if "confidence" not in segment:
+                    segment["confidence"] = 1.0  # Default confidence
+
+        # Extract the full text from all segments
+        user_text = ' '.join([segment['text'] for segment in result['segments']])
+
         # If no text was recognized, don't process further
         if not user_text or len(user_text.strip()) == 0:
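For reference only (not part of this commit): the WhisperX call sequence used in the hunk above, isolated from the Flask plumbing as a minimal sketch. The "small" model size, batch size 16, and the input file name are assumptions carried over from or invented for the example.

import whisperx

device = "cuda"  # or "cpu"
compute_type = "float16" if device == "cuda" else "float32"

# Same sequence as above: load model, load audio, transcribe, then align for word-level timestamps.
model = whisperx.load_model("small", device, compute_type=compute_type)
audio = whisperx.load_audio("sample.wav")  # hypothetical input file
result = model.transcribe(audio, batch_size=16)

align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
aligned = whisperx.align(result["segments"], align_model, metadata, audio, device,
                         return_char_alignments=False)

for segment in aligned["segments"]:
    print(round(segment["start"], 2), round(segment["end"], 2), segment["text"])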
@@ -371,11 +463,12 @@ def process_audio_and_respond(session_id, data):
             audio=waveform.squeeze()
         )
 
-        # Send transcription to client
+        # Send transcription to client with detailed segments
        with app.app_context():
             socketio.emit('transcription', {
                 'text': user_text,
-                'speaker': speaker_id
+                'speaker': speaker_id,
+                'segments': result['segments']  # Include the detailed segments with timestamps
             }, room=session_id)
 
         # Generate AI response using Llama
@@ -392,31 +485,41 @@ def process_audio_and_respond(session_id, data):
         prompt = f"{conversation_history}Assistant: "
 
         # Generate response with Llama
-        input_tokens = models.tokenizer(
-            prompt,
-            return_tensors="pt",
-            padding=True,
-            return_attention_mask=True
-        )
-        input_ids = input_tokens.input_ids.to(DEVICE)
-        attention_mask = input_tokens.attention_mask.to(DEVICE)
+        try:
+            # Ensure pad token is set
+            if models.tokenizer.pad_token is None:
+                models.tokenizer.pad_token = models.tokenizer.eos_token
 
-        with torch.no_grad():
-            generated_ids = models.llm.generate(
-                input_ids,
-                attention_mask=attention_mask,
-                max_new_tokens=100,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True,
-                pad_token_id=models.tokenizer.eos_token_id
-            )
+            input_tokens = models.tokenizer(
+                prompt,
+                return_tensors="pt",
+                padding=True,
+                return_attention_mask=True
+            )
+            input_ids = input_tokens.input_ids.to(DEVICE)
+            attention_mask = input_tokens.attention_mask.to(DEVICE)
 
-        # Decode the response
-        response_text = models.tokenizer.decode(
-            generated_ids[0][input_ids.shape[1]:],
-            skip_special_tokens=True
-        ).strip()
+            with torch.no_grad():
+                generated_ids = models.llm.generate(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=100,
+                    temperature=0.7,
+                    top_p=0.9,
+                    do_sample=True,
+                    pad_token_id=models.tokenizer.eos_token_id
+                )
+
+            # Decode the response
+            response_text = models.tokenizer.decode(
+                generated_ids[0][input_ids.shape[1]:],
+                skip_special_tokens=True
+            ).strip()
+        except Exception as e:
+            logger.error(f"Error generating response: {str(e)}")
+            import traceback
+            logger.error(traceback.format_exc())
+            response_text = "I'm sorry, I encountered an error while processing your request."
 
         # Synthesize speech
         with app.app_context():
@@ -43,7 +43,9 @@ const state = {
     volumeUpdateInterval: null,
     visualizerAnimationFrame: null,
    currentSpeaker: 0,
-    aiSpeakerId: 1 // Define the AI's speaker ID to match server.py
+    aiSpeakerId: 1, // Define the AI's speaker ID to match server.py
+    transcriptionRetries: 0,
+    maxTranscriptionRetries: 3
 };
 
 // Visualizer variables
@@ -429,7 +431,15 @@ function handleSpeechState(isSilent) {
 
     if (!hasAudioContent) {
         console.warn('Audio buffer appears to be empty or very quiet');
-        addSystemMessage('No speech detected. Please try again and speak clearly.');
+
+        if (state.transcriptionRetries < state.maxTranscriptionRetries) {
+            state.transcriptionRetries++;
+            const retryMessage = `No speech detected (attempt ${state.transcriptionRetries}/${state.maxTranscriptionRetries}). Please speak louder and try again.`;
+            addSystemMessage(retryMessage);
+        } else {
+            state.transcriptionRetries = 0;
+            addSystemMessage('Multiple attempts failed to detect speech. Please check your microphone and try again.');
+        }
         return;
     }
 