Demo Fixes 5

This commit is contained in:
2025-03-30 02:57:57 -04:00
parent 1344115013
commit bfaffef684
3 changed files with 60 additions and 3 deletions

View File

@@ -301,6 +301,30 @@
color: #888; color: #888;
margin-top: 5px; margin-top: 5px;
} }
/* Add this to your existing styles */
.loading-progress {
width: 100%;
max-width: 150px;
margin-right: 10px;
}
progress {
width: 100%;
height: 8px;
border-radius: 4px;
overflow: hidden;
}
progress::-webkit-progress-bar {
background-color: #eee;
border-radius: 4px;
}
progress::-webkit-progress-value {
background-color: var(--primary-color);
border-radius: 4px;
}
</style> </style>
</head> </head>
<body> <body>
@@ -318,6 +342,11 @@
<span id="statusText">Disconnected</span> <span id="statusText">Disconnected</span>
</div> </div>
<!-- Add this above the model status indicators in the chat-header div -->
<div class="loading-progress">
<progress id="modelLoadingProgress" max="100" value="0">0%</progress>
</div>
<!-- Add this model status panel --> <!-- Add this model status panel -->
<div class="model-status"> <div class="model-status">
<div id="csmStatus" class="model-indicator loading" title="Loading CSM model...">CSM</div> <div id="csmStatus" class="model-indicator loading" title="Loading CSM model...">CSM</div>

View File

@@ -59,17 +59,28 @@ class AppModels:
asr_model = None asr_model = None
asr_processor = None asr_processor = None
# Initialize the models object
models = AppModels()
def load_models(): def load_models():
"""Load all required models""" """Load all required models"""
global models global models
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})
logger.info("Loading CSM 1B model...") logger.info("Loading CSM 1B model...")
try: try:
models.generator = load_csm_1b(device=DEVICE) models.generator = load_csm_1b(device=DEVICE)
logger.info("CSM 1B model loaded successfully") logger.info("CSM 1B model loaded successfully")
socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'}) socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
progress = 33
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress})
if DEVICE == "cuda":
torch.cuda.empty_cache()
except Exception as e: except Exception as e:
logger.error(f"Error loading CSM 1B model: {str(e)}") import traceback
error_details = traceback.format_exc()
logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
logger.info("Loading Whisper ASR model...") logger.info("Loading Whisper ASR model...")
@@ -85,6 +96,10 @@ def load_models():
logger.info("Whisper ASR model loaded successfully") logger.info("Whisper ASR model loaded successfully")
socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
progress = 66
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress})
if DEVICE == "cuda":
torch.cuda.empty_cache()
except Exception as e: except Exception as e:
logger.error(f"Error loading ASR model: {str(e)}") logger.error(f"Error loading ASR model: {str(e)}")
socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
@@ -99,6 +114,8 @@ def load_models():
models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
logger.info("Llama 3.2 model loaded successfully") logger.info("Llama 3.2 model loaded successfully")
socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
progress = 100
socketio.emit('model_status', {'model': 'overall', 'status': 'loaded', 'progress': progress})
except Exception as e: except Exception as e:
logger.error(f"Error loading Llama 3.2 model: {str(e)}") logger.error(f"Error loading Llama 3.2 model: {str(e)}")
socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)}) socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})

View File

@@ -909,9 +909,20 @@ function finalizeStreamingAudio() {
streamingAudio.audioElement = null; streamingAudio.audioElement = null;
} }
// Handle model status updates // Enhance the handleModelStatusUpdate function:
function handleModelStatusUpdate(data) { function handleModelStatusUpdate(data) {
const { model, status, message } = data; const { model, status, message, progress } = data;
if (model === 'overall' && status === 'loading') {
// Update overall loading progress
const progressBar = document.getElementById('modelLoadingProgress');
if (progressBar) {
progressBar.value = progress;
progressBar.textContent = `${progress}%`;
}
return;
}
if (status === 'loaded') { if (status === 'loaded') {
console.log(`Model ${model} loaded successfully`); console.log(`Model ${model} loaded successfully`);