Merge branch 'main' of https://github.com/GamerBoss101/HooHacks-12
@@ -61,22 +61,44 @@ def load_models():
     global models

     logger.info("Loading CSM 1B model...")
-    models.generator = load_csm_1b(device=DEVICE)
+    try:
+        models.generator = load_csm_1b(device=DEVICE)
+        logger.info("CSM 1B model loaded successfully")
+        socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
+    except Exception as e:
+        logger.error(f"Error loading CSM 1B model: {str(e)}")
+        socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})

     logger.info("Loading ASR pipeline...")
-    models.asr = pipeline(
-        "automatic-speech-recognition",
-        model="openai/whisper-small",
-        device=DEVICE
-    )
+    try:
+        # Initialize the pipeline without the language parameter in the constructor
+        models.asr = pipeline(
+            "automatic-speech-recognition",
+            model="openai/whisper-small",
+            device=DEVICE
+        )
+
+        # Configure the model with the appropriate options
+        # Note that for whisper, language should be set during inference, not initialization
+        logger.info("ASR pipeline loaded successfully")
+        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
+    except Exception as e:
+        logger.error(f"Error loading ASR pipeline: {str(e)}")
+        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})

     logger.info("Loading Llama 3.2 model...")
-    models.llm = AutoModelForCausalLM.from_pretrained(
-        "meta-llama/Llama-3.2-1B",
-        device_map=DEVICE,
-        torch_dtype=torch.bfloat16
-    )
-    models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+    try:
+        models.llm = AutoModelForCausalLM.from_pretrained(
+            "meta-llama/Llama-3.2-1B",
+            device_map=DEVICE,
+            torch_dtype=torch.bfloat16
+        )
+        models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+        logger.info("Llama 3.2 model loaded successfully")
+        socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
+    except Exception as e:
+        logger.error(f"Error loading Llama 3.2 model: {str(e)}")
+        socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})

 # Load models in a background thread
 threading.Thread(target=load_models, daemon=True).start()
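Review note: each loader now wraps its work in try/except and reports progress over the 'model_status' socket event, so one failed model no longer kills the whole background thread. Below is a minimal sketch of a client waiting on those events, using the python-socketio client and assuming the server's default localhost:5000; adjust the URL for your deployment. One caveat: these emits fire once from the startup thread, so a client that connects after loading finishes will miss them, which is what the /api/status endpoint added below covers.

# Sketch: block until the server has announced all three models.
# Assumes the default host/port; 'csm', 'asr', 'llm' match the emits above.
import socketio

sio = socketio.Client()
loaded = set()

@sio.on('model_status')
def on_model_status(data):
    if data['status'] == 'loaded':
        loaded.add(data['model'])
    else:
        print(f"{data['model']} failed to load: {data.get('message')}")

sio.connect('http://localhost:5000')
while loaded < {'csm', 'asr', 'llm'}:
    sio.sleep(1)
print('All models ready')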
@@ -118,6 +140,20 @@ def health_check():
         "models_loaded": models.generator is not None and models.llm is not None
     })

+# Add a system status endpoint
+@app.route('/api/status')
+def system_status():
+    return jsonify({
+        "status": "ok",
+        "cuda_available": torch.cuda.is_available(),
+        "device": DEVICE,
+        "models": {
+            "generator": models.generator is not None,
+            "asr": models.asr is not None,
+            "llm": models.llm is not None
+        }
+    })
+
 # Socket event handlers
 @socketio.on('connect')
 def handle_connect(auth=None):
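The new endpoint lets late-joining clients check readiness over plain HTTP instead of relying on socket events they may have missed. A minimal polling sketch with requests, assuming the default port 5000 from the startup block later in this diff:

# Sketch: poll /api/status until every model reports ready.
import time
import requests

def wait_for_models(base_url="http://localhost:5000", timeout=300):
    deadline = time.time() + timeout
    while time.time() < deadline:
        status = requests.get(f"{base_url}/api/status").json()
        if all(status["models"].values()):
            return status
        time.sleep(2)
    raise TimeoutError("models did not finish loading in time")

print(wait_for_models())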
@@ -225,10 +261,12 @@ def process_audio_queue(session_id, q):
 def process_audio_and_respond(session_id, data):
     """Process audio data and generate a response"""
     if models.generator is None or models.asr is None or models.llm is None:
+        logger.warning("Models not yet loaded!")
         with app.app_context():
             socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
         return

+    logger.info(f"Processing audio for session {session_id}")
     conversation = active_conversations[session_id]

     try:
@@ -238,9 +276,15 @@ def process_audio_and_respond(session_id, data):
         # Process base64 audio data
         audio_data = data['audio']
         speaker_id = data['speaker']
+        logger.info(f"Received audio from speaker {speaker_id}")

         # Convert from base64 to WAV
-        audio_bytes = base64.b64decode(audio_data.split(',')[1])
+        try:
+            audio_bytes = base64.b64decode(audio_data.split(',')[1])
+            logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
+        except Exception as e:
+            logger.error(f"Error decoding base64 audio: {str(e)}")
+            raise

         # Save to temporary file for processing
         with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
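The split(',')[1] here assumes the client sends a data URL of the form data:audio/wav;base64,<payload>, which matches the FileReader.readAsDataURL call on the client side of this same commit; the new try/except turns a malformed payload into a logged, re-raised error instead of an opaque failure deeper in the pipeline. A round-trip sketch of the expected format, with stand-in WAV bytes:

# Sketch: the data-URL shape the decode step expects.
import base64

wav_bytes = b"RIFF....WAVEfmt "  # stand-in for real WAV bytes
data_url = "data:audio/wav;base64," + base64.b64encode(wav_bytes).decode("ascii")

# Server side: strip the "data:...;base64," prefix, then decode.
payload = data_url.split(",")[1]
assert base64.b64decode(payload) == wav_bytes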
@@ -270,7 +314,8 @@ def process_audio_and_respond(session_id, data):
         # Use the ASR pipeline to transcribe
         transcription_result = models.asr(
             {"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
-            return_timestamps=False
+            return_timestamps=False,
+            generate_kwargs={"language": "en"}  # Set language during inference
         )
         user_text = transcription_result['text'].strip()

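Per the comment added in load_models above, the language is pinned at inference time via generate_kwargs rather than at pipeline construction. A standalone sketch of the same call shape; the input array here is placeholder silence, not data from the repo:

# Sketch: pin Whisper to English at inference time.
import numpy as np
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")
samples = np.zeros(16000, dtype=np.float32)  # one second of silence as a stand-in
result = asr(
    {"array": samples, "sampling_rate": 16000},
    return_timestamps=False,
    generate_kwargs={"language": "en"},
)
print(result["text"])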
@@ -308,11 +353,19 @@ def process_audio_and_respond(session_id, data):
         prompt = f"{conversation_history}Assistant: "

         # Generate response with Llama
-        input_ids = models.tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
+        input_tokens = models.tokenizer(
+            prompt,
+            return_tensors="pt",
+            padding=True,
+            return_attention_mask=True
+        )
+        input_ids = input_tokens.input_ids.to(DEVICE)
+        attention_mask = input_tokens.attention_mask.to(DEVICE)

         with torch.no_grad():
             generated_ids = models.llm.generate(
                 input_ids,
+                attention_mask=attention_mask,  # Add the attention mask
                 max_new_tokens=100,
                 temperature=0.7,
                 top_p=0.9,
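Passing the attention mask explicitly silences the transformers warning about inferring it from the pad token. One thing to verify in review: Llama tokenizers ship without a pad token, and transformers raises "Asking to pad but the tokenizer does not have a padding token" when padding=True is requested without one, so this call may fail unless a pad token is assigned elsewhere in the repo. A self-contained sketch of the usual fix, plus decoding only the newly generated tokens (the decode step is an assumption; it is not part of this hunk):

# Sketch: pad-token setup and prompt-stripping decode. Model names match
# the diff; everything else here is illustrative, not from the repo.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
if tokenizer.pad_token is None:  # Llama has no pad token by default
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
inputs = tokenizer("Assistant: ", return_tensors="pt",
                   padding=True, return_attention_mask=True)

with torch.no_grad():
    generated_ids = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=20,
    )

# Decode only the tokens produced after the prompt.
print(tokenizer.decode(generated_ids[0][inputs.input_ids.shape[1]:],
                       skip_special_tokens=True))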
@@ -437,5 +490,6 @@ cleanup_thread.start()
 # Start the server
 if __name__ == '__main__':
     port = int(os.environ.get('PORT', 5000))
-    logger.info(f"Starting server on port {port}")
-    socketio.run(app, host='0.0.0.0', port=port, debug=False, allow_unsafe_werkzeug=True)
+    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
+    logger.info(f"Starting server on port {port} (debug={debug_mode})")
+    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
@@ -378,25 +378,39 @@ function handleSpeechState(isSilent) {
         if (state.isSpeaking) {
             state.isSpeaking = false;

-            // Get the current audio data and send it
-            const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max
-            state.analyser.getFloatTimeDomainData(audioBuffer);
-
-            // Create WAV blob
-            const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate);
-
-            // Convert to base64
-            const reader = new FileReader();
-            reader.onloadend = function() {
-                sendAudioChunk(reader.result, state.currentSpeaker);
-            };
-            reader.readAsDataURL(wavBlob);
-
-            // Update button state
-            elements.streamButton.classList.add('processing');
-            elements.streamButton.innerHTML = '<i class="fas fa-cog fa-spin"></i> Processing...';
-
-            addSystemMessage('Processing your message...');
+            try {
+                // Get the current audio data and send it
+                const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max
+                state.analyser.getFloatTimeDomainData(audioBuffer);
+
+                // Check if audio has content
+                const hasAudioContent = audioBuffer.some(sample => Math.abs(sample) > 0.01);
+
+                if (!hasAudioContent) {
+                    console.warn('Audio buffer appears to be empty or very quiet');
+                    addSystemMessage('No speech detected. Please try again and speak clearly.');
+                    return;
+                }
+
+                // Create WAV blob
+                const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate);
+
+                // Convert to base64
+                const reader = new FileReader();
+                reader.onloadend = function() {
+                    sendAudioChunk(reader.result, state.currentSpeaker);
+                };
+                reader.readAsDataURL(wavBlob);
+
+                // Update button state
+                elements.streamButton.classList.add('processing');
+                elements.streamButton.innerHTML = '<i class="fas fa-cog fa-spin"></i> Processing...';
+
+                addSystemMessage('Processing your message...');
+            } catch (e) {
+                console.error('Error recording audio:', e);
+                addSystemMessage('Error recording audio. Please try again.');
+            }
         }
     }, CLIENT_SILENCE_DURATION_MS);
 }
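The new hasAudioContent guard rejects buffers whose peak amplitude never exceeds 0.01 before anything is uploaded. If the server ever needs the same defense against clients that skip this check, a hypothetical counterpart on the decoded waveform could look like the sketch below; the helper name and the reuse of the 0.01 threshold are assumptions, not part of this commit.

# Hypothetical sketch: server-side counterpart of the client's amplitude
# check, applied to a decoded waveform. Not part of this commit.
import numpy as np

def has_audio_content(waveform: np.ndarray, threshold: float = 0.01) -> bool:
    """Return True if any sample exceeds the amplitude threshold."""
    return bool(np.max(np.abs(waveform)) > threshold)

print(has_audio_content(np.zeros(16000, dtype=np.float32)))       # False
print(has_audio_content(np.full(16000, 0.5, dtype=np.float32)))   # True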