Demo Fixes 7

2025-03-30 03:09:57 -04:00
parent fdb92ff061
commit 284dd50972
2 changed files with 83 additions and 20 deletions


@@ -25,6 +25,10 @@ import whisperx
 from generator import load_csm_1b, Segment
 from dataclasses import dataclass
+# Add these imports at the top
+import psutil
+import gc
 # Configure logging
 logging.basicConfig(level=logging.INFO,
                     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
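
For context on the new imports: psutil and gc are typically paired with torch.cuda.empty_cache() to reclaim memory between heavy operations. A minimal sketch of that pattern (free_memory is a hypothetical helper, not part of this commit):

import gc
import logging

import psutil
import torch

logger = logging.getLogger(__name__)

def free_memory():
    # Force Python garbage collection, then release cached CUDA blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # psutil reports process memory so the effect can be logged.
    rss_gb = psutil.Process().memory_info().rss / (1024 ** 3)
    logger.info(f"Process RSS after cleanup: {rss_gb:.2f} GiB")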
@@ -68,13 +72,13 @@ def load_models():
     socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})
     logger.info("Loading CSM 1B model...")
     # CSM 1B loading
     try:
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
         models.generator = load_csm_1b(device=DEVICE)
         logger.info("CSM 1B model loaded successfully")
         socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
-        progress = 33
-        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress})
+        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
+        if DEVICE == "cuda":
+            torch.cuda.empty_cache()
     except Exception as e:
@@ -83,8 +87,9 @@ def load_models():
logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
logger.info("Loading Whisper ASR model...")
# Whisper loading
try:
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
# Use regular Whisper instead of WhisperX to avoid compatibility issues
from transformers import WhisperProcessor, WhisperForConditionalGeneration
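
The lines that actually load the Whisper model are elided by the hunk above; with plain transformers that step usually looks like the sketch below (the checkpoint id openai/whisper-small is an assumption, not taken from the commit):

from transformers import WhisperProcessor, WhisperForConditionalGeneration

models.asr_processor = WhisperProcessor.from_pretrained("openai/whisper-small")  # assumed checkpoint
models.asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(DEVICE)
models.asr_model.eval()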
@@ -96,16 +101,16 @@ def load_models():
logger.info("Whisper ASR model loaded successfully")
socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
progress = 66
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress})
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
if DEVICE == "cuda":
torch.cuda.empty_cache()
except Exception as e:
logger.error(f"Error loading ASR model: {str(e)}")
socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
logger.info("Loading Llama 3.2 model...")
# Llama loading
try:
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
models.llm = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B",
device_map=DEVICE,
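
The remaining from_pretrained arguments are cut off by the next hunk; a hypothetical completion, for illustration only (torch_dtype and the tokenizer line are assumptions, not from the commit):

models.llm = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    device_map=DEVICE,
    torch_dtype=torch.float16,  # assumed: half precision to reduce GPU memory
)
models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed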
@@ -123,8 +128,8 @@ def load_models():
logger.info("Llama 3.2 model loaded successfully")
socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
progress = 100
socketio.emit('model_status', {'model': 'overall', 'status': 'loaded', 'progress': progress})
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
except Exception as e:
logger.error(f"Error loading Llama 3.2 model: {str(e)}")
socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})
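
On the client side, the model_status events emitted above can be consumed with the python-socketio package; a minimal sketch (server URL assumed):

import socketio

sio = socketio.Client()

@sio.on('model_status')
def on_model_status(data):
    # Payload mirrors the emits above: model, status, and optional progress/message.
    progress = data.get('progress')
    suffix = f" ({progress}%)" if progress is not None else ""
    print(f"{data.get('model')}: {data.get('status')}{suffix}")

sio.connect('http://localhost:5000')  # assumed address of the Flask-SocketIO server
sio.wait()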
@@ -184,6 +189,39 @@ def system_status():
         }
     })

+# Add a new endpoint to check system resources
+@app.route('/api/system_resources')
+def system_resources():
+    # Get CPU usage
+    cpu_percent = psutil.cpu_percent(interval=0.1)
+
+    # Get memory usage
+    memory = psutil.virtual_memory()
+    memory_used_gb = memory.used / (1024 ** 3)
+    memory_total_gb = memory.total / (1024 ** 3)
+    memory_percent = memory.percent
+
+    # Get GPU memory if available
+    gpu_memory = {}
+    if torch.cuda.is_available():
+        for i in range(torch.cuda.device_count()):
+            gpu_memory[f"gpu_{i}"] = {
+                "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
+                "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
+                "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
+            }
+
+    return jsonify({
+        "cpu_percent": cpu_percent,
+        "memory": {
+            "used_gb": memory_used_gb,
+            "total_gb": memory_total_gb,
+            "percent": memory_percent
+        },
+        "gpu_memory": gpu_memory,
+        "active_sessions": len(active_conversations)
+    })
+
 # Socket event handlers
 @socketio.on('connect')
 def handle_connect(auth=None):
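
A quick way to exercise the new /api/system_resources endpoint (host and port are assumptions):

import requests

resp = requests.get("http://localhost:5000/api/system_resources")  # assumed host/port
resp.raise_for_status()
stats = resp.json()
print(f"CPU {stats['cpu_percent']}% | RAM {stats['memory']['percent']}% "
      f"({stats['memory']['used_gb']:.1f}/{stats['memory']['total_gb']:.1f} GiB)")
for name, mem in stats["gpu_memory"].items():
    print(f"{name}: {mem['allocated']:.2f} GiB allocated, {mem['reserved']:.2f} GiB reserved")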
@@ -331,18 +369,33 @@ def process_audio_and_respond(session_id, data):
         speech_array, sampling_rate = librosa.load(temp_path, sr=16000)
         # Convert to required format
-        input_features = models.asr_processor(
+        processor_output = models.asr_processor(
             speech_array,
             sampling_rate=sampling_rate,
-            return_tensors="pt"
-        ).input_features.to(DEVICE)
-        # Generate token ids
-        predicted_ids = models.asr_model.generate(
-            input_features,
-            language="en",
-            task="transcribe"
+            return_tensors="pt",
+            padding=True,  # Add padding
+            return_attention_mask=True  # Request attention mask
         )
+        input_features = processor_output.input_features.to(DEVICE)
+        attention_mask = processor_output.get('attention_mask', None)
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(DEVICE)
+            # Generate token ids with attention mask
+            predicted_ids = models.asr_model.generate(
+                input_features,
+                attention_mask=attention_mask,
+                language="en",
+                task="transcribe"
+            )
+        else:
+            # Fallback if attention mask is not available
+            predicted_ids = models.asr_model.generate(
+                input_features,
+                language="en",
+                task="transcribe"
+            )
         # Decode the predicted ids to text
         user_text = models.asr_processor.batch_decode(
             predicted_ids, skip_special_tokens=True
         )[0]
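
End to end, the fixed transcription path boils down to the following self-contained sketch (the checkpoint id and input file are assumptions, not from the commit):

import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
processor = WhisperProcessor.from_pretrained("openai/whisper-small")  # assumed checkpoint
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(DEVICE)

speech, sr = librosa.load("sample.wav", sr=16000)  # assumed input file
inputs = processor(speech, sampling_rate=sr, return_tensors="pt",
                   padding=True, return_attention_mask=True)
features = inputs.input_features.to(DEVICE)
mask = inputs.get("attention_mask")
extra = {"attention_mask": mask.to(DEVICE)} if mask is not None else {}
ids = model.generate(features, language="en", task="transcribe", **extra)
print(processor.batch_decode(ids, skip_special_tokens=True)[0])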