# HooHacks-12/Backend/server.py

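"""Flask + Socket.IO backend for the Sesame AI voice chat demo.

Serves the static front end, transcribes streamed microphone audio with WhisperX,
picks a reply with a small keyword-based responder, and synthesizes speech with the
Sesame CSM-1B generator. Per-client conversation context, audio buffers, and
energy-based silence detection state live in the `active_clients` dict below.
"""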
import os
import base64
import json
import torch
import torchaudio
import numpy as np
import whisperx
from io import BytesIO
from typing import List, Dict, Any, Optional
from flask import Flask, request, send_from_directory, Response
from flask_cors import CORS
from flask_socketio import SocketIO, emit, disconnect
from generator import load_csm_1b, Segment
import time
import gc
from collections import deque
from threading import Lock

# Select device
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Initialize the CSM speech generator
generator = load_csm_1b(device=device)

# Initialize WhisperX for ASR
print("Loading WhisperX model...")
# "medium" is used instead of a larger model for faster response times.
# float16 compute requires GPU support, so fall back to int8 on CPU.
compute_type = "float16" if device == "cuda" else "int8"
asr_model = whisperx.load_model("medium", device, compute_type=compute_type)
print("WhisperX model loaded!")

# Silence detection parameters
SILENCE_THRESHOLD = 0.01    # Average absolute amplitude below this counts as silence; adjust for your audio normalization
SILENCE_DURATION_SEC = 1.0  # How long silence must persist before the utterance is processed
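
# Energy-based silence detection (used by handle_stream_audio below): each incoming
# chunk's mean absolute amplitude is pushed into a rolling window, and the stream is
# treated as silent once the window average stays below SILENCE_THRESHOLD for
# SILENCE_DURATION_SEC.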
# Define the base directory
base_dir = os.path.dirname(os.path.abspath(__file__))
static_dir = os.path.join(base_dir, "static")
os.makedirs(static_dir, exist_ok=True)
# Setup Flask
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
# Socket connection management
thread = None
thread_lock = Lock()
active_clients = {} # Map client_id to client context
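# Each active_clients entry maps a Socket.IO session id to a dict with:
#   'context_segments'  - list[Segment]: the conversation so far, fed back into the generator
#   'streaming_buffer'  - list[torch.Tensor]: audio chunks awaiting transcription
#   'is_streaming'      - bool: whether a streaming session is in progress
#   'is_silence'        - bool: whether the rolling energy average is below the threshold
#   'last_active_time'  - float: timestamp used to measure how long silence has persisted
#   'energy_window'     - deque: recent chunk energies used for silence detection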

# Helper function to convert audio data
def decode_audio_data(audio_data: str) -> torch.Tensor:
    """Decode base64 audio data to a torch tensor"""
    try:
        # Extract the actual base64 content
        if ',' in audio_data:
            audio_data = audio_data.split(',')[1]
        # Decode base64 audio data
        binary_data = base64.b64decode(audio_data)
        # Load audio from binary data
        with BytesIO(binary_data) as temp_file:
            audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")
        # Resample if needed
        if sample_rate != generator.sample_rate:
            audio_tensor = torchaudio.functional.resample(
                audio_tensor.squeeze(0),
                orig_freq=sample_rate,
                new_freq=generator.sample_rate
            )
        else:
            audio_tensor = audio_tensor.squeeze(0)
        return audio_tensor
    except Exception as e:
        print(f"Error decoding audio: {str(e)}")
        # Return a small silent audio segment as fallback
        return torch.zeros(generator.sample_rate // 2)  # 0.5 seconds of silence

def encode_audio_data(audio_tensor: torch.Tensor) -> str:
    """Encode torch tensor audio to base64 string"""
    buf = BytesIO()
    torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
    buf.seek(0)
    audio_base64 = base64.b64encode(buf.read()).decode('utf-8')
    return f"data:audio/wav;base64,{audio_base64}"

def transcribe_audio(audio_tensor: torch.Tensor) -> str:
    """Transcribe audio using WhisperX"""
    temp_path = os.path.join(base_dir, "temp_audio.wav")
    try:
        # Save the tensor to a temporary file
        torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
        # Load and transcribe the audio
        audio = whisperx.load_audio(temp_path)
        result = asr_model.transcribe(audio, batch_size=16)
        # Clean up
        if os.path.exists(temp_path):
            os.remove(temp_path)
        # Get the transcription text
        if result["segments"] and len(result["segments"]) > 0:
            # Combine all segments
            transcription = " ".join([segment["text"] for segment in result["segments"]])
            return transcription.strip()
        else:
            return ""
    except Exception as e:
        print(f"Error in transcription: {str(e)}")
        if os.path.exists(temp_path):
            os.remove(temp_path)
        return ""

def generate_response(text: str, conversation_history: List[Segment]) -> str:
    """Generate a contextual response based on the transcribed text"""
    # Simple response logic - can be replaced with a more sophisticated LLM in the future
    responses = {
        "hello": "Hello there! How are you doing today?",
        "how are you": "I'm doing well, thanks for asking! How about you?",
        "what is your name": "I'm Sesame, your voice assistant. How can I help you?",
        "bye": "Goodbye! It was nice chatting with you.",
        "thank you": "You're welcome! Is there anything else I can help with?",
        "weather": "I don't have real-time weather data, but I hope it's nice where you are!",
        "help": "I can chat with you using natural voice. Just speak normally and I'll respond.",
    }
    text_lower = text.lower()
    # Check for matching keywords
    for key, response in responses.items():
        if key in text_lower:
            return response
    # Default responses based on text length
    if not text:
        return "I didn't catch that. Could you please repeat?"
    elif len(text) < 10:
        return "Thanks for your message. Could you elaborate a bit more?"
    else:
        return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"

# Flask routes for serving static content
@app.route('/')
def index():
    return send_from_directory(base_dir, 'index.html')

@app.route('/favicon.ico')
def favicon():
    if os.path.exists(os.path.join(static_dir, 'favicon.ico')):
        return send_from_directory(static_dir, 'favicon.ico')
    return Response(status=204)

@app.route('/voice-chat.js')
def voice_chat_js():
    return send_from_directory(base_dir, 'voice-chat.js')

@app.route('/static/<path:path>')
def serve_static(path):
    return send_from_directory(static_dir, path)

# Socket.IO event handlers
@socketio.on('connect')
def handle_connect():
    client_id = request.sid
    print(f"Client connected: {client_id}")
    # Initialize client context
    active_clients[client_id] = {
        'context_segments': [],
        'streaming_buffer': [],
        'is_streaming': False,
        'is_silence': False,
        'last_active_time': time.time(),
        'energy_window': deque(maxlen=10)
    }
    emit('status', {'type': 'connected', 'message': 'Connected to server'})

@socketio.on('disconnect')
def handle_disconnect():
    client_id = request.sid
    if client_id in active_clients:
        del active_clients[client_id]
    print(f"Client disconnected: {client_id}")

@socketio.on('generate')
def handle_generate(data):
    client_id = request.sid
    if client_id not in active_clients:
        emit('error', {'message': 'Client not registered'})
        return
    try:
        text = data.get('text', '')
        speaker_id = data.get('speaker', 0)
        print(f"Generating audio for: '{text}' with speaker {speaker_id}")
        # Generate audio response
        audio_tensor = generator.generate(
            text=text,
            speaker=speaker_id,
            context=active_clients[client_id]['context_segments'],
            max_audio_length_ms=10_000,
        )
        # Add to conversation context
        active_clients[client_id]['context_segments'].append(
            Segment(text=text, speaker=speaker_id, audio=audio_tensor)
        )
        # Convert audio to base64 and send back to client
        audio_base64 = encode_audio_data(audio_tensor)
        emit('audio_response', {
            'type': 'audio_response',
            'audio': audio_base64
        })
    except Exception as e:
        print(f"Error generating audio: {str(e)}")
        emit('error', {
            'type': 'error',
            'message': f"Error generating audio: {str(e)}"
        })

@socketio.on('add_to_context')
def handle_add_to_context(data):
    client_id = request.sid
    if client_id not in active_clients:
        emit('error', {'message': 'Client not registered'})
        return
    try:
        text = data.get('text', '')
        speaker_id = data.get('speaker', 0)
        audio_data = data.get('audio', '')
        # Convert received audio to tensor
        audio_tensor = decode_audio_data(audio_data)
        # Add to conversation context
        active_clients[client_id]['context_segments'].append(
            Segment(text=text, speaker=speaker_id, audio=audio_tensor)
        )
        emit('context_updated', {
            'type': 'context_updated',
            'message': 'Audio added to context'
        })
    except Exception as e:
        print(f"Error adding to context: {str(e)}")
        emit('error', {
            'type': 'error',
            'message': f"Error processing audio: {str(e)}"
        })

@socketio.on('clear_context')
def handle_clear_context():
    client_id = request.sid
    if client_id in active_clients:
        active_clients[client_id]['context_segments'] = []
    emit('context_updated', {
        'type': 'context_updated',
        'message': 'Context cleared'
    })

@socketio.on('stream_audio')
def handle_stream_audio(data):
    client_id = request.sid
    if client_id not in active_clients:
        emit('error', {'message': 'Client not registered'})
        return
    client = active_clients[client_id]
    try:
        speaker_id = data.get('speaker', 0)
        audio_data = data.get('audio', '')
        # Convert received audio to tensor
        audio_chunk = decode_audio_data(audio_data)
        # Start streaming mode if not already started
        if not client['is_streaming']:
            client['is_streaming'] = True
            client['streaming_buffer'] = []
            client['energy_window'].clear()
            client['is_silence'] = False
            client['last_active_time'] = time.time()
            print(f"[{client_id}] Streaming started with speaker ID: {speaker_id}")
            emit('streaming_status', {
                'type': 'streaming_status',
                'status': 'started'
            })
        # Calculate audio energy for silence detection
        chunk_energy = torch.mean(torch.abs(audio_chunk)).item()
        client['energy_window'].append(chunk_energy)
        avg_energy = sum(client['energy_window']) / len(client['energy_window'])
        # Check if audio is silent
        current_silence = avg_energy < SILENCE_THRESHOLD
        # Track silence transition
        if not client['is_silence'] and current_silence:
            # Transition to silence
            client['is_silence'] = True
            client['last_active_time'] = time.time()
        elif client['is_silence'] and not current_silence:
            # User started talking again
            client['is_silence'] = False
        # Add chunk to buffer regardless of silence state
        client['streaming_buffer'].append(audio_chunk)
        # Check if silence has persisted long enough to consider "stopped talking"
        silence_elapsed = time.time() - client['last_active_time']
        if client['is_silence'] and silence_elapsed >= SILENCE_DURATION_SEC and len(client['streaming_buffer']) > 0:
            # User has stopped talking - process the collected audio
            print(f"[{client_id}] Processing audio after {silence_elapsed:.2f}s of silence")
            full_audio = torch.cat(client['streaming_buffer'], dim=0)
            # Process with WhisperX speech-to-text
            print(f"[{client_id}] Starting transcription with WhisperX...")
            transcribed_text = transcribe_audio(full_audio)
            # Log the transcription
            print(f"[{client_id}] Transcribed text: '{transcribed_text}'")
            # Add to conversation context
            if transcribed_text:
                user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
                client['context_segments'].append(user_segment)
                # Generate a contextual response
                response_text = generate_response(transcribed_text, client['context_segments'])
                # Send the transcribed text to client
                emit('transcription', {
                    'type': 'transcription',
                    'text': transcribed_text
                })
                # Generate audio for the response
                audio_tensor = generator.generate(
                    text=response_text,
                    speaker=1 if speaker_id == 0 else 0,  # Use opposite speaker
                    context=client['context_segments'],
                    max_audio_length_ms=10_000,
                )
                # Add response to context
                ai_segment = Segment(
                    text=response_text,
                    speaker=1 if speaker_id == 0 else 0,
                    audio=audio_tensor
                )
                client['context_segments'].append(ai_segment)
                # Convert audio to base64 and send back to client
                audio_base64 = encode_audio_data(audio_tensor)
                emit('audio_response', {
                    'type': 'audio_response',
                    'text': response_text,
                    'audio': audio_base64
                })
            else:
                # If transcription failed, send a generic response
                emit('error', {
                    'type': 'error',
                    'message': "Sorry, I couldn't understand what you said. Could you try again?"
                })
            # Clear buffer and reset silence detection
            client['streaming_buffer'] = []
            client['energy_window'].clear()
            client['is_silence'] = False
            client['last_active_time'] = time.time()
        # If buffer gets too large without silence, process it anyway
        elif len(client['streaming_buffer']) >= 30:  # ~6 seconds of audio at 5 chunks/sec
            full_audio = torch.cat(client['streaming_buffer'], dim=0)
            # Process with WhisperX speech-to-text
            transcribed_text = transcribe_audio(full_audio)
            if transcribed_text:
                client['context_segments'].append(
                    Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
                )
                # Send the transcribed text to client
                emit('transcription', {
                    'type': 'transcription',
                    'text': transcribed_text + " (processing continued speech...)"
                })
            client['streaming_buffer'] = []
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"Error processing streaming audio: {str(e)}")
        emit('error', {
            'type': 'error',
            'message': f"Error processing streaming audio: {str(e)}"
        })

@socketio.on('stop_streaming')
def handle_stop_streaming(data):
    client_id = request.sid
    if client_id not in active_clients:
        return
    client = active_clients[client_id]
    client['is_streaming'] = False
    if client['streaming_buffer'] and len(client['streaming_buffer']) > 5:
        # Process any remaining audio in the buffer
        full_audio = torch.cat(client['streaming_buffer'], dim=0)
        # Process with WhisperX speech-to-text
        transcribed_text = transcribe_audio(full_audio)
        if transcribed_text:
            client['context_segments'].append(
                Segment(text=transcribed_text, speaker=data.get("speaker", 0), audio=full_audio)
            )
            # Send the transcribed text to client
            emit('transcription', {
                'type': 'transcription',
                'text': transcribed_text
            })
        client['streaming_buffer'] = []
    emit('streaming_status', {
        'type': 'streaming_status',
        'status': 'stopped'
    })

if __name__ == "__main__":
    print(f"\n{'='*60}")
    print("🔊 Sesame AI Voice Chat Server (Flask Implementation)")
    print(f"{'='*60}")
    print("📡 Server Information:")
    print("   - Local URL: http://localhost:5000")
    print("   - Network URL: http://<your-ip-address>:5000")
    print("   - WebSocket: ws://<your-ip-address>:5000/socket.io")
    print(f"{'='*60}")
    print("💡 To make this server public:")
    print("   1. Ensure port 5000 is open in your firewall")
    print("   2. Set up port forwarding on your router to port 5000")
    print("   3. Or use a service like ngrok with: ngrok http 5000")
    print(f"{'='*60}")
    print(f"🌐 Device: {device.upper()}")
    print(f"🧠 Models loaded: Sesame CSM + WhisperX ({asr_model.device})")
    print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}")
    print(f"{'='*60}")
    print("Ready to receive connections! Press Ctrl+C to stop the server.\n")
    socketio.run(app, host="0.0.0.0", port=5000, debug=False)
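
# Example client usage (a minimal sketch, assuming the python-socketio package;
# the bundled browser client is voice-chat.js, so this is illustrative only):
#
#   import socketio
#
#   sio = socketio.Client()
#
#   @sio.on('audio_response')
#   def on_audio_response(data):
#       # data['audio'] is a "data:audio/wav;base64,..." string
#       print("Received", len(data['audio']), "characters of audio data")
#
#   sio.connect('http://localhost:5000')
#   sio.emit('generate', {'text': 'Hello there', 'speaker': 0})
#   sio.wait()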