# HooHacks-12/Backend/server.py

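"""Real-time voice chat backend (Flask + Socket.IO over eventlet).

Pipeline per utterance: the browser streams audio chunks over Socket.IO,
webrtcvad segments the speech, Faster-Whisper transcribes it, Llama 3.2
generates a short text reply, and the CSM model synthesizes speech that is
streamed back to the client as base64-encoded WAV chunks.
"""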
import eventlet
# monkey_patch() must run before any other imports that touch sockets or threading
eventlet.monkey_patch()
import os
import io
import base64
import time
import torch
import torchaudio
import numpy as np
from flask import Flask, render_template, request
from flask_socketio import SocketIO, emit
from transformers import AutoModelForCausalLM, AutoTokenizer
import threading
import queue
import requests
import huggingface_hub
from generator import load_csm_1b, Segment
from collections import deque
import json
import webrtcvad # For voice activity detection
# Configure environment with longer timeouts
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600" # 10 minutes timeout for downloads
requests.adapters.DEFAULT_TIMEOUT = 60 # Increase default requests timeout
# Create a models directory for caching
os.makedirs("models", exist_ok=True)
app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key'
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
# Explicitly check for CUDA and print detailed info
print("\n=== CUDA Information ===")
if torch.cuda.is_available():
    print("CUDA is available")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("CUDA is not available")
# Check for cuDNN
try:
    import ctypes
    ctypes.CDLL("libcudnn_ops_infer.so.8")
    print("cuDNN is available")
except OSError:
    print("cuDNN is not available (libcudnn_ops_infer.so.8 not found)")
# Determine compute device
try:
    if torch.cuda.is_available():
        device = "cuda"
        whisper_compute_type = "float16"
        print("🟢 CUDA is available and initialized successfully")
    elif torch.backends.mps.is_available():
        device = "mps"
        whisper_compute_type = "float32"
        print("🟢 MPS is available (Apple Silicon)")
    else:
        device = "cpu"
        whisper_compute_type = "int8"
        print("🟡 Using CPU (CUDA/MPS not available)")
except Exception as e:
    print(f"🔴 Error initializing CUDA: {e}")
    print("🔴 Falling back to CPU")
    device = "cpu"
    whisper_compute_type = "int8"
print(f"Using device: {device}")
# Initialize models with proper error handling
whisper_model = None
csm_generator = None
llm_model = None
llm_tokenizer = None
vad = None
# Constants
SAMPLE_RATE = 16000 # For VAD
VAD_FRAME_SIZE = 480 # 30ms at 16kHz for VAD
VAD_MODE = 3 # Aggressive mode for better results
AUDIO_CHUNK_SIZE = 2400 # 100ms chunks when streaming AI voice
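# Note: webrtcvad only accepts 16-bit mono PCM frames of exactly 10, 20 or 30 ms at
# 8, 16, 32 or 48 kHz; VAD_FRAME_SIZE (480 samples) is the 30 ms frame at 16 kHz that
# detect_speech() slices out of each incoming chunk.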
# Audio sample rates
CLIENT_SAMPLE_RATE = 44100 # Browser WebAudio default
WHISPER_SAMPLE_RATE = 16000 # Whisper expects 16kHz
# Session data structures
user_sessions = {} # session_id -> complete session data
# WebRTC ICE servers (STUN/TURN servers for NAT traversal)
ICE_SERVERS = [
    {"urls": "stun:stun.l.google.com:19302"},
    {"urls": "stun:stun1.l.google.com:19302"},
]
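# Sent to the browser in the 'config' payload (see handle_get_config), presumably
# for its RTCPeerConnection setup.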
def load_models():
    """Load all necessary models"""
    global whisper_model, csm_generator, llm_model, llm_tokenizer, vad

    # Initialize Voice Activity Detector
    try:
        vad = webrtcvad.Vad(VAD_MODE)
        print("Voice Activity Detector initialized")
    except Exception as e:
        print(f"Error initializing VAD: {e}")
        vad = None

    # Initialize Faster-Whisper for transcription
    try:
        print("Loading Whisper model...")
        from faster_whisper import WhisperModel
        whisper_model = WhisperModel(
            "base",
            device=device,
            compute_type=whisper_compute_type,
            download_root="./models/whisper",
        )
        print("Whisper model loaded successfully")
    except Exception as e:
        print(f"Error loading Whisper model: {e}")
        print("Will use backup speech recognition method if available")

    # Initialize CSM model for audio generation
    try:
        print("Loading CSM model...")
        csm_generator = load_csm_1b(device=device)
        print("CSM model loaded successfully")
    except Exception as e:
        print(f"Error loading CSM model: {e}")
        print("Audio generation will not be available")

    # Initialize Llama 3.2 model for response generation
    try:
        print("Loading Llama 3.2 model...")
        llm_model_id = "meta-llama/Llama-3.2-1B"
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama")
        dtype = torch.bfloat16 if device != "cpu" else torch.float32
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_id,
            torch_dtype=dtype,
            device_map=device,
            cache_dir="./models/llama",
            low_cpu_mem_usage=True,
        )
        print("Llama 3.2 model loaded successfully")
    except Exception as e:
        print(f"Error loading Llama 3.2 model: {e}")
        print("Will use a fallback response generation method")
@app.route('/')
def index():
    """Serve the main interface"""
    return render_template('index.html')


@app.route('/static/js/voice-chat.js')
def serve_voice_chat_js():
    """Serve the JavaScript file"""
    return app.send_static_file('js/voice-chat.js')
@socketio.on('connect')
def handle_connect():
    """Handle new client connection"""
    session_id = request.sid
    print(f"Client connected: {session_id}")

    # Initialize session data
    user_sessions[session_id] = {
        # Conversation context
        'segments': [],
        'conversation_history': [],
        'is_turn_active': False,
        # Audio buffers and state
        'vad_buffer': deque(maxlen=30),  # ~1s of audio at 30fps
        'audio_buffer': bytearray(),
        'is_user_speaking': False,
        'last_vad_active': time.time(),
        'silence_duration': 0,
        'speech_frames': 0,
        # AI state
        'is_ai_speaking': False,
        'should_interrupt_ai': False,
        'ai_stream_queue': queue.Queue(),
        # Processing flags
        'is_processing': False,
        'pending_user_audio': None
    }

    # Send config to client
    emit('session_ready', {
        'whisper_available': whisper_model is not None,
        'csm_available': csm_generator is not None,
        'llm_available': llm_model is not None,
        'client_sample_rate': CLIENT_SAMPLE_RATE,
        'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000
    })
    emit('ready_for_speech', {'message': 'Ready to start conversation'})
@socketio.on('disconnect')
def handle_disconnect():
    """Handle client disconnection"""
    session_id = request.sid
    print(f"Client disconnected: {session_id}")

    # Clean up resources
    if session_id in user_sessions:
        # Signal any running threads to stop
        user_sessions[session_id]['should_interrupt_ai'] = True
        del user_sessions[session_id]
@socketio.on('webrtc_signal')
def handle_webrtc_signal(data):
    """Handle WebRTC signaling for P2P connection establishment"""
    session_id = request.sid
    if session_id not in user_sessions:
        return
    # Simply relay the signal to the client
    # In a multi-user app, we would route this to the correct peer
    emit('webrtc_signal', data)


@socketio.on('webrtc_connected')
def handle_webrtc_connected(data):
    """Client notifies that WebRTC connection is established"""
    session_id = request.sid
    if session_id not in user_sessions:
        return
    user_sessions[session_id]['webrtc_connected'] = True
    print(f"WebRTC connected for session {session_id}")
    emit('ready_for_speech', {'message': 'Ready to start conversation'})
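# Speech endpointing: a user turn starts once roughly five VAD-positive frames have
# accumulated, and ends after 0.8 s of silence following more than eight speech
# frames (see the thresholds in handle_audio_stream below).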
@socketio.on('audio_stream')
def handle_audio_stream(data):
    """Process incoming audio stream packets from client"""
    session_id = request.sid
    if session_id not in user_sessions:
        return
    session = user_sessions[session_id]
    try:
        # Decode audio data
        audio_bytes = base64.b64decode(data.get('audio', ''))
        if not audio_bytes or len(audio_bytes) < 2:  # Need at least one 16-bit sample
            return

        # Add to current audio buffer
        session['audio_buffer'] += audio_bytes

        # Check for speech using VAD
        has_speech = detect_speech(audio_bytes, session_id)

        # Handle speech state machine
        if has_speech:
            # Reset silence tracking when speech is detected
            session['last_vad_active'] = time.time()
            session['silence_duration'] = 0
            session['speech_frames'] += 1
            # If not already marked as speaking and we have enough speech frames
            if not session['is_user_speaking'] and session['speech_frames'] >= 5:
                on_speech_started(session_id)
        else:
            # No speech detected in this frame
            if session['is_user_speaking']:
                # Calculate silence duration
                now = time.time()
                session['silence_duration'] = now - session['last_vad_active']
                # If silent for more than 0.8 seconds, end the speech segment
                if session['silence_duration'] > 0.8 and session['speech_frames'] > 8:
                    on_speech_ended(session_id)
            else:
                # Not speaking and no speech, just a silent frame
                session['speech_frames'] = max(0, session['speech_frames'] - 1)
    except Exception as e:
        print(f"Error processing audio stream: {e}")
def detect_speech(audio_bytes, session_id):
    """Use VAD to check if audio contains speech"""
    if session_id not in user_sessions:
        return False
    session = user_sessions[session_id]

    # Store in VAD buffer for history
    session['vad_buffer'].append(audio_bytes)

    if vad is None:
        # Fallback to simple energy detection
        audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
        energy = np.mean(np.abs(audio_data)) / 32768.0
        return energy > 0.015  # Simple threshold

    try:
        # Ensure we have the right amount of data for VAD
        audio_data = np.frombuffer(audio_bytes, dtype=np.int16)

        # If we have too much data, use just the right amount
        if len(audio_data) >= VAD_FRAME_SIZE:
            frame = audio_data[:VAD_FRAME_SIZE].tobytes()
            return vad.is_speech(frame, SAMPLE_RATE)

        # If too little data, accumulate in the VAD buffer and check periodically
        if len(session['vad_buffer']) >= 3:
            # Combine recent chunks to get enough data
            combined = bytearray()
            for chunk in list(session['vad_buffer'])[-3:]:
                combined.extend(chunk)
            # Extract a full frame; combined holds raw bytes, so a frame of
            # VAD_FRAME_SIZE samples is VAD_FRAME_SIZE * 2 bytes of int16 PCM
            if len(combined) >= VAD_FRAME_SIZE * 2:
                frame = bytes(combined[:VAD_FRAME_SIZE * 2])
                return vad.is_speech(frame, SAMPLE_RATE)
        return False
    except Exception as e:
        print(f"VAD error: {e}")
        return False
def on_speech_started(session_id):
    """Handle start of user speech"""
    if session_id not in user_sessions:
        return
    session = user_sessions[session_id]

    # Reset audio buffer
    session['audio_buffer'] = bytearray()
    session['is_user_speaking'] = True
    session['is_turn_active'] = True

    # If AI is speaking, we need to interrupt it
    if session['is_ai_speaking']:
        session['should_interrupt_ai'] = True
        socketio.emit('ai_interrupted_by_user', room=session_id)

    # Notify client that we detected speech
    socketio.emit('user_speech_start', room=session_id)
def on_speech_ended(session_id):
    """Handle end of user speech segment"""
    if session_id not in user_sessions:
        return
    session = user_sessions[session_id]

    # Mark as not speaking anymore
    session['is_user_speaking'] = False
    session['speech_frames'] = 0

    # Skip very short buffers (4000 bytes ≈ 45 ms of 16-bit audio at 44.1 kHz)
    # or segments that arrive while a previous one is still being processed
    if len(session['audio_buffer']) < 4000 or session['is_processing']:
        session['audio_buffer'] = bytearray()
        return

    # Mark as processing to prevent multiple processes
    session['is_processing'] = True

    # Take ownership of the audio buffer and reset it
    audio_copy = session['audio_buffer']
    session['audio_buffer'] = bytearray()

    # Convert audio to the format needed for processing
    try:
        # Convert to float32 between -1 and 1
        audio_np = np.frombuffer(audio_copy, dtype=np.int16).astype(np.float32) / 32768.0
        audio_tensor = torch.from_numpy(audio_np)

        # Resample to Whisper's expected sample rate if necessary
        if CLIENT_SAMPLE_RATE != WHISPER_SAMPLE_RATE:
            audio_tensor = torchaudio.functional.resample(
                audio_tensor,
                orig_freq=CLIENT_SAMPLE_RATE,
                new_freq=WHISPER_SAMPLE_RATE
            )

        # Save as WAV for transcription
        temp_audio_path = f"temp_audio_{session_id}.wav"
        torchaudio.save(
            temp_audio_path,
            audio_tensor.unsqueeze(0),
            WHISPER_SAMPLE_RATE
        )

        # Start transcription and response process in a thread
        threading.Thread(
            target=process_user_utterance,
            args=(session_id, temp_audio_path, audio_tensor),
            daemon=True
        ).start()

        # Notify client that processing has started
        socketio.emit('processing_speech', room=session_id)
    except Exception as e:
        print(f"Error preparing audio: {e}")
        session['is_processing'] = False
        socketio.emit('error', {'message': f'Error processing audio: {str(e)}'}, room=session_id)
def process_user_utterance(session_id, audio_path, audio_tensor):
    """Process user utterance, transcribe and generate response (runs in a worker thread)"""
    if session_id not in user_sessions:
        return
    session = user_sessions[session_id]
    try:
        # Transcribe audio
        if whisper_model is not None:
            user_text = transcribe_with_whisper(audio_path)
        else:
            # Fallback to another transcription service
            user_text = transcribe_fallback(audio_path)

        # Clean up temp file
        if os.path.exists(audio_path):
            os.remove(audio_path)

        # Check if we got meaningful text
        if not user_text or len(user_text.strip()) < 2:
            # socketio.emit (not emit): this thread has no request context
            socketio.emit('no_speech_detected', room=session_id)
            session['is_processing'] = False
            return

        print(f"Transcribed: {user_text}")

        # Create user segment
        user_segment = Segment(
            text=user_text,
            speaker=0,  # User is speaker 0
            audio=audio_tensor
        )
        session['segments'].append(user_segment)

        # Update conversation history
        session['conversation_history'].append({
            'role': 'user',
            'text': user_text
        })

        # Send transcription to client
        socketio.emit('transcription', {'text': user_text}, room=session_id)

        # Generate AI response
        ai_response = generate_ai_response(user_text, session_id)

        # Send text response to client
        socketio.emit('ai_response_text', {'text': ai_response}, room=session_id)

        # Update conversation history
        session['conversation_history'].append({
            'role': 'assistant',
            'text': ai_response
        })

        # Generate voice response if CSM is available
        if csm_generator is not None:
            session['is_ai_speaking'] = True
            session['should_interrupt_ai'] = False
            # Begin streaming audio response
            threading.Thread(
                target=stream_ai_response,
                args=(ai_response, session_id),
                daemon=True
            ).start()
    except Exception as e:
        print(f"Error processing utterance: {e}")
        socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
    finally:
        # Clear processing flag
        if session_id in user_sessions:
            session['is_processing'] = False
def transcribe_with_whisper(audio_path):
    """Transcribe audio using Faster-Whisper"""
    segments, info = whisper_model.transcribe(audio_path, beam_size=5)
    # Collect all text from segments
    user_text = ""
    for segment in segments:
        user_text += segment.text.strip() + " "
    return user_text.strip()
def transcribe_fallback(audio_path):
    """Fallback transcription using Google's speech recognition"""
    try:
        import speech_recognition as sr
        recognizer = sr.Recognizer()
        with sr.AudioFile(audio_path) as source:
            audio = recognizer.record(source)
        try:
            text = recognizer.recognize_google(audio)
            return text
        except sr.UnknownValueError:
            return ""
        except sr.RequestError:
            return "[Speech recognition service unavailable]"
    except ImportError:
        return "[Speech recognition not available]"
def generate_ai_response(user_text, session_id):
    """Generate text response using available LLM"""
    if session_id not in user_sessions:
        return "I'm sorry, your session has expired."
    session = user_sessions[session_id]

    if llm_model is not None and llm_tokenizer is not None:
        # Format conversation history for the LLM
        prompt = "You are a helpful, friendly voice assistant. Keep your responses brief and conversational.\n\n"

        # Add recent conversation history (last 6 turns maximum)
        for entry in session['conversation_history'][-6:]:
            if entry['role'] == 'user':
                prompt += f"User: {entry['text']}\n"
            else:
                prompt += f"Assistant: {entry['text']}\n"

        # Add current query if not already in history
        if not session['conversation_history'] or session['conversation_history'][-1]['role'] != 'user':
            prompt += f"User: {user_text}\n"
        prompt += "Assistant: "

        try:
            # Generate response
            inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
            output = llm_model.generate(
                inputs.input_ids,
                max_new_tokens=100,  # Keep responses shorter for voice
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )
            response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            print(f"Error generating LLM response: {e}")
            return fallback_response(user_text)
    else:
        return fallback_response(user_text)
def fallback_response(user_text):
    """Generate simple fallback responses when LLM is unavailable"""
    user_text_lower = user_text.lower()
    if "hello" in user_text_lower or "hi" in user_text_lower:
        return "Hello! How can I help you today?"
    elif "how are you" in user_text_lower:
        return "I'm doing well, thanks for asking! How about you?"
    elif "thank" in user_text_lower:
        return "You're welcome! Happy to help."
    elif "bye" in user_text_lower or "goodbye" in user_text_lower:
        return "Goodbye! Have a great day!"
    elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]):
        return "That's an interesting question. I wish I could provide a better answer in my current fallback mode."
    else:
        return "I see. Tell me more about that."
def stream_ai_response(text, session_id):
    """Generate and stream audio response in real-time chunks (runs in a worker thread)"""
    if session_id not in user_sessions:
        return
    session = user_sessions[session_id]
    try:
        # Signal start of AI speech (socketio.emit: no request context in this thread)
        socketio.emit('ai_speech_start', room=session_id)

        # Use the last few conversation segments as context (up to 4)
        context_segments = session['segments'][-4:] if len(session['segments']) > 4 else session['segments']

        # Generate audio for bot response
        audio = csm_generator.generate(
            text=text,
            speaker=1,  # Bot is speaker 1
            context=context_segments,
            max_audio_length_ms=10000,  # 10 seconds max
            temperature=0.9,
            topk=50
        )

        # Create and store bot segment
        bot_segment = Segment(
            text=text,
            speaker=1,
            audio=audio
        )
        if session_id in user_sessions:
            session['segments'].append(bot_segment)

        # Stream audio in small chunks for more responsive playback
        chunk_size = AUDIO_CHUNK_SIZE  # Size defined in constants
        for i in range(0, len(audio), chunk_size):
            # Check if we should stop (user interrupted)
            if session_id not in user_sessions or session['should_interrupt_ai']:
                print("AI speech interrupted")
                break

            # Get next chunk
            chunk = audio[i:i + chunk_size]

            # Convert audio chunk to base64-encoded WAV for streaming
            audio_bytes = io.BytesIO()
            torchaudio.save(audio_bytes, chunk.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav")
            audio_bytes.seek(0)
            audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')

            # Send chunk to client
            socketio.emit('ai_speech_chunk', {
                'audio': audio_b64,
                'is_last': i + chunk_size >= len(audio)
            }, room=session_id)

            # Small sleep for more natural pacing and smoother playback
            time.sleep(0.06)

        # Signal end of AI speech
        if session_id in user_sessions:
            session['is_ai_speaking'] = False
            session['is_turn_active'] = False  # End conversation turn
            socketio.emit('ai_speech_end', room=session_id)
    except Exception as e:
        print(f"Error streaming AI response: {e}")
        if session_id in user_sessions:
            session['is_ai_speaking'] = False
            session['is_turn_active'] = False
            socketio.emit('error', {'message': f'Error generating audio: {str(e)}'}, room=session_id)
            socketio.emit('ai_speech_end', room=session_id)
@socketio.on('interrupt_ai')
def handle_interrupt():
    """Handle explicit AI interruption request from client"""
    session_id = request.sid
    if session_id in user_sessions:
        user_sessions[session_id]['should_interrupt_ai'] = True
        emit('ai_interrupted', room=session_id)


@socketio.on('get_config')
def handle_get_config():
    """Send configuration to client"""
    session_id = request.sid
    if session_id in user_sessions:
        emit('config', {
            'client_sample_rate': CLIENT_SAMPLE_RATE,
            'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000,
            'whisper_available': whisper_model is not None,
            'csm_available': csm_generator is not None,
            'ice_servers': ICE_SERVERS
        })
if __name__ == '__main__':
    # Ensure the existing index.html file is in the correct location
    if not os.path.exists('templates'):
        os.makedirs('templates')
    if os.path.exists('index.html') and not os.path.exists('templates/index.html'):
        os.rename('index.html', 'templates/index.html')

    # Load models before starting the server
    print("Starting model loading...")
    load_models()

    # Start the server
    print("Starting Flask SocketIO server...")
    socketio.run(app, host='0.0.0.0', port=5000, debug=False)
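# A sketch of one way to run this server (assumed; exact versions are not pinned in
# this file):
#   pip install eventlet flask flask-socketio faster-whisper webrtcvad \
#       transformers torch torchaudio numpy requests huggingface_hub
#   # `generator` (load_csm_1b, Segment) is a local module, presumably the CSM
#   # project's generator.py placed alongside this file.
#   python server.py   # listens on http://0.0.0.0:5000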