BGV
2025-03-30 02:36:16 -04:00
2 changed files with 105 additions and 37 deletions

View File

@@ -61,22 +61,44 @@ def load_models():
     global models
     logger.info("Loading CSM 1B model...")
-    models.generator = load_csm_1b(device=DEVICE)
+    try:
+        models.generator = load_csm_1b(device=DEVICE)
+        logger.info("CSM 1B model loaded successfully")
+        socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
+    except Exception as e:
+        logger.error(f"Error loading CSM 1B model: {str(e)}")
+        socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
 
     logger.info("Loading ASR pipeline...")
-    models.asr = pipeline(
-        "automatic-speech-recognition",
-        model="openai/whisper-small",
-        device=DEVICE
-    )
+    try:
+        # Initialize the pipeline without the language parameter in the constructor
+        models.asr = pipeline(
+            "automatic-speech-recognition",
+            model="openai/whisper-small",
+            device=DEVICE
+        )
+        # Configure the model with the appropriate options
+        # Note that for whisper, language should be set during inference, not initialization
+        logger.info("ASR pipeline loaded successfully")
+        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
+    except Exception as e:
+        logger.error(f"Error loading ASR pipeline: {str(e)}")
+        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
 
     logger.info("Loading Llama 3.2 model...")
-    models.llm = AutoModelForCausalLM.from_pretrained(
-        "meta-llama/Llama-3.2-1B",
-        device_map=DEVICE,
-        torch_dtype=torch.bfloat16
-    )
-    models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+    try:
+        models.llm = AutoModelForCausalLM.from_pretrained(
+            "meta-llama/Llama-3.2-1B",
+            device_map=DEVICE,
+            torch_dtype=torch.bfloat16
+        )
+        models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+        logger.info("Llama 3.2 model loaded successfully")
+        socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
+    except Exception as e:
+        logger.error(f"Error loading Llama 3.2 model: {str(e)}")
+        socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})
 
 # Load models in a background thread
 threading.Thread(target=load_models, daemon=True).start()
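Each model now reports its outcome through a `model_status` Socket.IO event. As a minimal sketch of how a client might consume those events — assuming the python-socketio package and a placeholder server address, neither of which is specified in this commit:

import socketio

sio = socketio.Client()

@sio.on('model_status')
def on_model_status(data):
    # data looks like {'model': 'csm', 'status': 'loaded'} or
    # {'model': 'asr', 'status': 'error', 'message': '...'}
    if data.get('status') == 'error':
        print(f"{data['model']} failed to load: {data.get('message')}")
    else:
        print(f"{data['model']}: {data['status']}")

sio.connect('http://localhost:5000')  # placeholder address, not from this commit
sio.wait()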
@@ -118,6 +140,20 @@ def health_check():
"models_loaded": models.generator is not None and models.llm is not None "models_loaded": models.generator is not None and models.llm is not None
}) })
# Add a system status endpoint
@app.route('/api/status')
def system_status():
return jsonify({
"status": "ok",
"cuda_available": torch.cuda.is_available(),
"device": DEVICE,
"models": {
"generator": models.generator is not None,
"asr": models.asr is not None,
"llm": models.llm is not None
}
})
# Socket event handlers # Socket event handlers
@socketio.on('connect') @socketio.on('connect')
def handle_connect(auth=None): def handle_connect(auth=None):
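The new `/api/status` endpoint makes readiness pollable. A minimal sketch of a readiness wait built on it, assuming the requests package and a placeholder base URL:

import time
import requests

def wait_for_models(base_url="http://localhost:5000", timeout=300):
    # Poll /api/status until every model flag in the response is True.
    deadline = time.time() + timeout
    while time.time() < deadline:
        status = requests.get(f"{base_url}/api/status", timeout=5).json()
        if all(status["models"].values()):
            return status
        time.sleep(2)
    raise TimeoutError("models did not finish loading in time")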
@@ -225,10 +261,12 @@ def process_audio_queue(session_id, q):
 def process_audio_and_respond(session_id, data):
     """Process audio data and generate a response"""
     if models.generator is None or models.asr is None or models.llm is None:
+        logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
         return
 
+    logger.info(f"Processing audio for session {session_id}")
     conversation = active_conversations[session_id]
 
     try:
@@ -238,9 +276,15 @@ def process_audio_and_respond(session_id, data):
         # Process base64 audio data
         audio_data = data['audio']
         speaker_id = data['speaker']
+        logger.info(f"Received audio from speaker {speaker_id}")
 
         # Convert from base64 to WAV
-        audio_bytes = base64.b64decode(audio_data.split(',')[1])
+        try:
+            audio_bytes = base64.b64decode(audio_data.split(',')[1])
+            logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
+        except Exception as e:
+            logger.error(f"Error decoding base64 audio: {str(e)}")
+            raise
 
         # Save to temporary file for processing
         with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
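The decode step relies on the client sending a data URL, so `split(',')[1]` strips the `data:audio/wav;base64,` prefix. A self-contained sketch of that round-trip (the dummy bytes are illustrative, not a valid WAV file):

import base64

# The client-side FileReader.readAsDataURL produces this shape of string.
data_url = "data:audio/wav;base64," + base64.b64encode(b"RIFF....WAVE").decode("ascii")
audio_bytes = base64.b64decode(data_url.split(",")[1])
assert audio_bytes.startswith(b"RIFF")  # real WAV files begin with a RIFF header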
@@ -270,7 +314,8 @@ def process_audio_and_respond(session_id, data):
         # Use the ASR pipeline to transcribe
         transcription_result = models.asr(
             {"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
-            return_timestamps=False
+            return_timestamps=False,
+            generate_kwargs={"language": "en"}  # Set language during inference
         )
 
         user_text = transcription_result['text'].strip()
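This hunk moves the Whisper language setting from pipeline construction to inference time via `generate_kwargs`, matching the comment added in the loading code. A standalone sketch of the same call shape, assuming a 16 kHz mono float32 array (silence stands in for real audio):

import numpy as np
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")
audio = np.zeros(16000, dtype=np.float32)  # one second of silence as a stand-in
result = asr(
    {"array": audio, "sampling_rate": 16000},
    return_timestamps=False,
    generate_kwargs={"language": "en"},  # per-call, as in the change above
)
print(result["text"])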
@@ -308,11 +353,19 @@ def process_audio_and_respond(session_id, data):
prompt = f"{conversation_history}Assistant: " prompt = f"{conversation_history}Assistant: "
# Generate response with Llama # Generate response with Llama
input_ids = models.tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE) input_tokens = models.tokenizer(
prompt,
return_tensors="pt",
padding=True,
return_attention_mask=True
)
input_ids = input_tokens.input_ids.to(DEVICE)
attention_mask = input_tokens.attention_mask.to(DEVICE)
with torch.no_grad(): with torch.no_grad():
generated_ids = models.llm.generate( generated_ids = models.llm.generate(
input_ids, input_ids,
attention_mask=attention_mask, # Add the attention mask
max_new_tokens=100, max_new_tokens=100,
temperature=0.7, temperature=0.7,
top_p=0.9, top_p=0.9,
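Passing the attention mask explicitly matters because Llama-style tokenizers reuse the EOS token as the pad token, so `generate()` cannot tell padding from content on its own. A minimal sketch of the pattern, with gpt2 standing in for the actual model to keep it lightweight:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token      # gpt2, like Llama, ships without a pad token
tok.padding_side = "left"          # left-pad for decoder-only generation
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tok(["Hello", "A longer prompt here"], return_tensors="pt", padding=True)
with torch.no_grad():
    out = model.generate(
        batch.input_ids,
        attention_mask=batch.attention_mask,  # padding positions are masked out
        max_new_tokens=10,
        pad_token_id=tok.pad_token_id,
    )
print(tok.batch_decode(out, skip_special_tokens=True))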
@@ -437,5 +490,6 @@ cleanup_thread.start()
 # Start the server
 if __name__ == '__main__':
     port = int(os.environ.get('PORT', 5000))
-    logger.info(f"Starting server on port {port}")
-    socketio.run(app, host='0.0.0.0', port=port, debug=False, allow_unsafe_werkzeug=True)
+    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
+    logger.info(f"Starting server on port {port} (debug={debug_mode})")
+    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
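The new DEBUG flag is parsed by comparing the lowercased string to 'true'. A slightly more tolerant variant of the same idea, accepting common truthy spellings (the helper name is illustrative, not part of this commit):

import os

def env_flag(name, default=False):
    # Treat "1", "true", "yes", "on" (any case) as True; anything else as False.
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

debug_mode = env_flag("DEBUG")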

View File

@@ -378,25 +378,39 @@ function handleSpeechState(isSilent) {
         if (state.isSpeaking) {
             state.isSpeaking = false;
 
-            // Get the current audio data and send it
-            const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max
-            state.analyser.getFloatTimeDomainData(audioBuffer);
-
-            // Create WAV blob
-            const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate);
-
-            // Convert to base64
-            const reader = new FileReader();
-            reader.onloadend = function() {
-                sendAudioChunk(reader.result, state.currentSpeaker);
-            };
-            reader.readAsDataURL(wavBlob);
-
-            // Update button state
-            elements.streamButton.classList.add('processing');
-            elements.streamButton.innerHTML = '<i class="fas fa-cog fa-spin"></i> Processing...';
-
-            addSystemMessage('Processing your message...');
+            try {
+                // Get the current audio data and send it
+                const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max
+                state.analyser.getFloatTimeDomainData(audioBuffer);
+
+                // Check if audio has content
+                const hasAudioContent = audioBuffer.some(sample => Math.abs(sample) > 0.01);
+
+                if (!hasAudioContent) {
+                    console.warn('Audio buffer appears to be empty or very quiet');
+                    addSystemMessage('No speech detected. Please try again and speak clearly.');
+                    return;
+                }
+
+                // Create WAV blob
+                const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate);
+
+                // Convert to base64
+                const reader = new FileReader();
+                reader.onloadend = function() {
+                    sendAudioChunk(reader.result, state.currentSpeaker);
+                };
+                reader.readAsDataURL(wavBlob);
+
+                // Update button state
+                elements.streamButton.classList.add('processing');
+                elements.streamButton.innerHTML = '<i class="fas fa-cog fa-spin"></i> Processing...';
+
+                addSystemMessage('Processing your message...');
+            } catch (e) {
+                console.error('Error recording audio:', e);
+                addSystemMessage('Error recording audio. Please try again.');
+            }
         }
     }, CLIENT_SILENCE_DURATION_MS);
 }
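The client now gates uploads on a simple amplitude check (`Math.abs(sample) > 0.01`). The same guard could be mirrored server-side before transcription; a sketch in Python with numpy, using the client's threshold (the helper is an assumption, not part of this commit):

import numpy as np

def has_audio_content(samples: np.ndarray, threshold: float = 0.01) -> bool:
    # True if any sample exceeds the amplitude threshold used by the client.
    return bool(np.any(np.abs(samples) > threshold))

print(has_audio_content(np.zeros(16000, dtype=np.float32)))       # False: silence
print(has_audio_content(np.full(16000, 0.05, dtype=np.float32)))  # True: signal present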