Demo Update 6

This commit is contained in:
2025-03-30 00:24:26 -04:00
parent d83b078bc2
commit 6152e300c0


@@ -1,9 +1,13 @@
import os
import base64
import json
import time
import math
import gc
import logging
import numpy as np
import torch
import torchaudio
import whisperx
from io import BytesIO
from typing import List, Dict, Any, Optional
@@ -11,290 +15,314 @@ from flask import Flask, request, send_from_directory, Response
from flask_cors import CORS
from flask_socketio import SocketIO, emit, disconnect
from generator import load_csm_1b, Segment
from collections import deque
from threading import Lock
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("sesame-server")
# CUDA Environment Setup
def setup_cuda_environment():
"""Set up CUDA environment with proper error handling"""
# Search for CUDA libraries in common locations
cuda_lib_dirs = [
"/usr/local/cuda/lib64",
"/usr/lib/x86_64-linux-gnu",
"/usr/local/cuda/extras/CUPTI/lib64"
]
# Add directories to LD_LIBRARY_PATH if they exist
current_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
for cuda_dir in cuda_lib_dirs:
if os.path.exists(cuda_dir) and cuda_dir not in current_ld_path:
if current_ld_path:
os.environ['LD_LIBRARY_PATH'] = f"{current_ld_path}:{cuda_dir}"
else:
os.environ['LD_LIBRARY_PATH'] = cuda_dir
current_ld_path = os.environ['LD_LIBRARY_PATH']
logger.info(f"LD_LIBRARY_PATH set to: {os.environ.get('LD_LIBRARY_PATH', 'not set')}")
# Determine best compute device
device = "cpu"
compute_type = "int8"
try:
# Set CUDA preferences
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Limit to first GPU only
# Try enabling TF32 precision if available
try:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
except Exception as e:
logger.warning(f"Could not set advanced CUDA options: {e}")
# Test if CUDA is functional
if torch.cuda.is_available():
try:
# Test basic CUDA operations
x = torch.rand(10, device="cuda")
y = x + x
del x, y
torch.cuda.empty_cache()
device = "cuda"
compute_type = "float16"
logger.info("CUDA is fully functional")
except Exception as e:
logger.warning(f"CUDA available but not working correctly: {e}")
device = "cpu"
else:
logger.info("CUDA is not available, using CPU")
except Exception as e:
logger.error(f"Error setting up computing environment: {e}")
return device, compute_type
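# Explanatory note (editorial, based on standard PyTorch behavior): the allow_tf32 and
# cudnn.benchmark flags trade a small amount of numerical precision and first-iteration
# latency for higher throughput on recent NVIDIA GPUs (TF32 requires Ampere or newer);
# they are wrapped in try/except because older PyTorch builds may not expose them.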
# Set up the compute environment
device, compute_type = setup_cuda_environment()
# Constants and Configuration
SILENCE_THRESHOLD = 0.01
SILENCE_DURATION_SEC = 0.75
MAX_BUFFER_SIZE = 30 # Maximum chunks to buffer before processing
CHUNK_SIZE_MS = 500 # Size of audio chunks when streaming responses
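# Illustrative helper (an assumption, not called anywhere in this file): it shows how the
# constants above are intended to interact. Each incoming chunk's RMS energy is compared
# against SILENCE_THRESHOLD, and an utterance is treated as finished once chunks have
# stayed below the threshold for SILENCE_DURATION_SEC (the actual logic lives in
# handle_stream_audio further down).
def _chunk_is_silent(chunk: torch.Tensor, threshold: float = SILENCE_THRESHOLD) -> bool:
    """Return True if a mono audio chunk's RMS energy is below the silence threshold."""
    if chunk.numel() == 0:
        return True
    rms = torch.sqrt(torch.mean(chunk.float() ** 2)).item()
    return rms < threshold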
# Define the base directory and static files directory
base_dir = os.path.dirname(os.path.abspath(__file__))
static_dir = os.path.join(base_dir, "static")
os.makedirs(static_dir, exist_ok=True)
# Model Loading Functions
def load_speech_models():
"""Load all required speech models with fallbacks"""
# Load speech generation model (Sesame CSM)
try:
logger.info(f"Loading Sesame CSM model on {device}...")
generator = load_csm_1b(device=device)
logger.info("Sesame CSM model loaded successfully")
except Exception as e:
logger.error(f"Error loading Sesame CSM on {device}: {e}")
if device == "cuda":
try:
logger.info("Trying to load Sesame CSM on CPU instead...")
generator = load_csm_1b(device="cpu")
logger.info("Sesame CSM model loaded on CPU successfully")
except Exception as cpu_error:
logger.critical(f"Failed to load speech synthesis model: {cpu_error}")
raise RuntimeError("Failed to load speech synthesis model")
else:
raise RuntimeError("Failed to load speech synthesis model on any device")
# Load ASR model (WhisperX)
try:
logger.info("Loading WhisperX model...")
# Start with the tiny model on CPU for reliable initialization
asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
logger.info("WhisperX 'tiny' model loaded on CPU successfully")
# Try upgrading to GPU if available
if device == "cuda":
try:
logger.info("Trying to load WhisperX on CUDA...")
# Test with a tiny model first
test_audio = torch.zeros(16000) # 1 second of silence
cuda_model = whisperx.load_model("tiny", "cuda", compute_type="float16")
# Test the model with real inference
_ = cuda_model.transcribe(test_audio.numpy(), batch_size=1)
asr_model = cuda_model
logger.info("WhisperX model running on CUDA successfully")
# Try to upgrade to small model
try:
small_model = whisperx.load_model("small", "cuda", compute_type="float16")
_ = small_model.transcribe(test_audio.numpy(), batch_size=1)
asr_model = small_model
logger.info("WhisperX 'small' model loaded on CUDA successfully")
except Exception as e:
logger.warning(f"Staying with 'tiny' model on CUDA: {e}")
except Exception as e:
logger.warning(f"CUDA loading failed, staying with CPU model: {e}")
except Exception as e:
logger.error(f"Error loading WhisperX model: {e}")
# Create a minimal dummy model as last resort
class DummyModel:
def __init__(self):
self.device = "cpu"
def transcribe(self, *args, **kwargs):
return {"segments": [{"text": "Speech recognition currently unavailable."}]}
asr_model = DummyModel()
logger.warning("Using dummy transcription model - ASR functionality limited")
return generator, asr_model
# Load speech models
generator, asr_model = load_speech_models()
# Set up Flask and Socket.IO
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
# Socket connection management
thread = None
thread_lock = Lock()
active_clients = {} # Map client_id to client context
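# For reference: each active_clients entry is a dict of per-client state. Its exact
# initialization happens in handle_connect (elided from this diff); based on the fields
# referenced later in this file, a representative entry looks roughly like this
# (illustrative values only):
#
#   active_clients[client_id] = {
#       'context_segments': [],              # Segment history used as generation context
#       'streaming_buffer': [],              # pending audio chunks (torch tensors)
#       'energy_window': deque(maxlen=10),   # recent chunk energies for silence detection
#       'is_silence': False,                 # currently inside a silent stretch?
#       'last_active_time': time.time(),     # timestamp of last voice activity
#   }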
# Audio Utility Functions
def decode_audio_data(audio_data: str) -> torch.Tensor:
"""Decode base64 audio data to a torch tensor with improved error handling"""
try:
# Skip empty audio data
if not audio_data or len(audio_data) < 100:
print("Empty or too short audio data received")
logger.warning("Empty or too short audio data received")
return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence
# Extract the actual base64 content
if ',' in audio_data:
# Handle data URL format (data:audio/wav;base64,...)
audio_data = audio_data.split(',')[1]
# Decode base64 audio data
try:
binary_data = base64.b64decode(audio_data)
print(f"Decoded base64 data: {len(binary_data)} bytes")
logger.debug(f"Decoded base64 data: {len(binary_data)} bytes")
# Check if we have enough data for a valid WAV
if len(binary_data) < 44: # WAV header is 44 bytes
print("Data too small to be a valid WAV file")
logger.warning("Data too small to be a valid WAV file")
return torch.zeros(generator.sample_rate // 2)
except Exception as e:
print(f"Base64 decoding error: {str(e)}")
logger.error(f"Base64 decoding error: {e}")
return torch.zeros(generator.sample_rate // 2)
# Save for debugging
debug_path = os.path.join(base_dir, "debug_incoming.wav")
with open(debug_path, 'wb') as f:
f.write(binary_data)
print(f"Saved debug file: {debug_path}")
# Multiple approaches to handle audio data
audio_tensor = None
sample_rate = None
# Approach 1: Direct loading with torchaudio
try:
with BytesIO(binary_data) as temp_file:
temp_file.seek(0)
audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")
print(f"Direct loading success: shape={audio_tensor.shape}, rate={sample_rate}Hz")
logger.debug(f"Loaded audio: shape={audio_tensor.shape}, rate={sample_rate}Hz")
# Validate tensor
if audio_tensor.numel() == 0 or torch.isnan(audio_tensor).any():
raise ValueError("Invalid audio tensor")
except Exception as e:
print(f"Direct loading failed: {str(e)}")
logger.warning(f"Direct loading failed: {e}")
# Approach 2: Using wave module and numpy
try:
temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav")
with open(temp_path, 'wb') as f:
f.write(binary_data)
import wave
try:
with wave.open(temp_path, 'rb') as wf:
n_channels = wf.getnchannels()
sample_width = wf.getsampwidth()
sample_rate = wf.getframerate()
n_frames = wf.getnframes()
# Read the frames
frames = wf.readframes(n_frames)
logger.debug(f"Wave reading: channels={n_channels}, rate={sample_rate}Hz, frames={n_frames}")
# Convert to numpy array
if sample_width == 2: # 16-bit audio
data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
elif sample_width == 1: # 8-bit audio
data = np.frombuffer(frames, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0
else:
raise ValueError(f"Unsupported sample width: {sample_width}")
# Convert to mono if needed
if n_channels > 1:
data = data.reshape(-1, n_channels)
data = data.mean(axis=1)
# Convert to torch tensor
audio_tensor = torch.from_numpy(data)
logger.info(f"Loaded audio using wave: shape={audio_tensor.shape}")
except Exception as wave_error:
logger.warning(f"Wave processing failed: {wave_error}")
# Try with torchaudio as last resort
audio_tensor, sample_rate = torchaudio.load(temp_path, format="wav")
# Clean up temp file
if os.path.exists(temp_path):
os.remove(temp_path)
except Exception as e2:
logger.error(f"All audio loading methods failed: {e2}")
return torch.zeros(generator.sample_rate // 2)
# Format corrections
if audio_tensor is None:
return torch.zeros(generator.sample_rate // 2)
# Ensure audio is mono
if len(audio_tensor.shape) > 1 and audio_tensor.shape[0] > 1:
audio_tensor = torch.mean(audio_tensor, dim=0)
# Ensure 1D tensor
audio_tensor = audio_tensor.squeeze()
# Resample if needed
if sample_rate != generator.sample_rate:
try:
print(f"Resampling from {sample_rate}Hz to {generator.sample_rate}Hz")
logger.debug(f"Resampling from {sample_rate}Hz to {generator.sample_rate}Hz")
resampler = torchaudio.transforms.Resample(
orig_freq=sample_rate,
new_freq=generator.sample_rate
)
audio_tensor = resampler(audio_tensor)
except Exception as e:
print(f"Resampling error: {str(e)}")
# If resampling fails, just return the original audio
# The model can often handle different sample rates
logger.warning(f"Resampling error: {e}")
# Normalize audio to avoid issues
if torch.abs(audio_tensor).max() > 0:
audio_tensor = audio_tensor / torch.abs(audio_tensor).max()
print(f"Final audio tensor: shape={audio_tensor.shape}, min={audio_tensor.min().item():.4f}, max={audio_tensor.max().item():.4f}")
return audio_tensor
except Exception as e:
print(f"Unhandled error in decode_audio_data: {str(e)}")
# Return a small silent audio segment as fallback
return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence
logger.error(f"Unhandled error in decode_audio_data: {e}")
return torch.zeros(generator.sample_rate // 2)
def encode_audio_data(audio_tensor: torch.Tensor) -> str:
"""Encode torch tensor audio to base64 string"""
try:
buf = BytesIO()
torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
buf.seek(0)
audio_base64 = base64.b64encode(buf.read()).decode('utf-8')
return f"data:audio/wav;base64,{audio_base64}"
except Exception as e:
logger.error(f"Error encoding audio: {e}")
# Return a minimal silent audio file
silence = torch.zeros(generator.sample_rate // 2).unsqueeze(0)
buf = BytesIO()
torchaudio.save(buf, silence, generator.sample_rate, format="wav")
buf.seek(0)
return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
def transcribe_audio(audio_tensor: torch.Tensor) -> str:
"""Transcribe audio using WhisperX with robust error handling"""
global asr_model
try:
# Save the tensor to a temporary file
temp_path = os.path.join(base_dir, "temp_audio.wav")
temp_path = os.path.join(base_dir, f"temp_audio_{time.time()}.wav")
torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
print(f"Transcribing audio file: {temp_path} (size: {os.path.getsize(temp_path)} bytes)")
logger.info(f"Transcribing audio file: {os.path.getsize(temp_path)} bytes")
# Load the audio for WhisperX
try:
audio = whisperx.load_audio(temp_path)
except Exception as e:
logger.warning(f"WhisperX load_audio failed: {e}")
# Fall back to manual loading
import soundfile as sf
audio, sr = sf.read(temp_path)
@@ -302,59 +330,55 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
from scipy import signal
audio = signal.resample(audio, int(len(audio) * 16000 / sr))
# Transcribe with error handling
try:
result = asr_model.transcribe(audio, batch_size=4)
except RuntimeError as e:
if "CUDA" in str(e) or "libcudnn" in str(e):
logger.warning(f"CUDA error in transcription, falling back to CPU: {e}")
try:
# Try CPU model
cpu_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
result = cpu_model.transcribe(audio, batch_size=1)
# Update the global model if the original one is broken
asr_model = cpu_model
except Exception as cpu_e:
logger.error(f"CPU fallback failed: {cpu_e}")
return "I'm having trouble processing audio right now."
else:
# Re-raise if it's not a CUDA error
raise
finally:
# Clean up
if os.path.exists(temp_path):
os.remove(temp_path)
# Extract text from segments
if result["segments"] and len(result["segments"]) > 0:
transcription = " ".join([segment["text"] for segment in result["segments"]])
logger.info(f"Transcription: '{transcription.strip()}'")
return transcription.strip()
return ""
except Exception as e:
logger.error(f"Error in transcription: {e}")
if os.path.exists(temp_path):
os.remove(temp_path)
return "I heard something but couldn't understand it."
def generate_response(text: str, conversation_history: List[Segment]) -> str:
"""Generate a contextual response based on the transcribed text"""
# Simple response logic - can be replaced with a more sophisticated LLM
responses = {
"hello": "Hello there! How are you doing today?",
"hello": "Hello there! How can I help you today?",
"hi": "Hi there! What can I do for you?",
"how are you": "I'm doing well, thanks for asking! How about you?",
"what is your name": "I'm Sesame, your voice assistant. How can I help you?",
"who are you": "I'm Sesame, an AI voice assistant. I'm here to chat with you!",
"bye": "Goodbye! It was nice chatting with you.",
"thank you": "You're welcome! Is there anything else I can help with?",
"weather": "I don't have real-time weather data, but I hope it's nice where you are!",
"help": "I can chat with you using natural voice. Just speak normally and I'll respond.",
"what can you do": "I can have a conversation with you, answer questions, and provide assistance with various topics.",
}
text_lower = text.lower()
@@ -372,7 +396,7 @@ def generate_response(text: str, conversation_history: List[Segment]) -> str:
else:
return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"
# Flask Routes
@app.route('/')
def index():
return send_from_directory(base_dir, 'index.html')
@@ -391,11 +415,11 @@ def voice_chat_js():
def serve_static(path):
return send_from_directory(static_dir, path)
# Socket.IO Event Handlers
@socketio.on('connect')
def handle_connect():
client_id = request.sid
print(f"Client connected: {client_id}")
logger.info(f"Client connected: {client_id}")
# Initialize client context
active_clients[client_id] = {
@@ -414,7 +438,7 @@ def handle_disconnect():
client_id = request.sid
if client_id in active_clients:
del active_clients[client_id]
print(f"Client disconnected: {client_id}")
logger.info(f"Client disconnected: {client_id}")
@socketio.on('generate')
def handle_generate(data):
@@ -427,7 +451,7 @@ def handle_generate(data):
text = data.get('text', '')
speaker_id = data.get('speaker', 0)
print(f"Generating audio for: '{text}' with speaker {speaker_id}")
logger.info(f"Generating audio for: '{text}' with speaker {speaker_id}")
# Generate audio response
audio_tensor = generator.generate(
@@ -446,11 +470,12 @@ def handle_generate(data):
audio_base64 = encode_audio_data(audio_tensor)
emit('audio_response', {
'type': 'audio_response',
'audio': audio_base64,
'text': text
})
except Exception as e:
print(f"Error generating audio: {str(e)}")
logger.error(f"Error generating audio: {e}")
emit('error', {
'type': 'error',
'message': f"Error generating audio: {str(e)}"
@@ -482,7 +507,7 @@ def handle_add_to_context(data):
})
except Exception as e:
print(f"Error adding to context: {str(e)}")
logger.error(f"Error adding to context: {e}")
emit('error', {
'type': 'error',
'message': f"Error processing audio: {str(e)}"
@@ -512,6 +537,11 @@ def handle_stream_audio(data):
speaker_id = data.get('speaker', 0)
audio_data = data.get('audio', '')
# Skip if no audio data (might be just a connection test)
if not audio_data:
logger.debug("Empty audio data received, ignoring")
return
# Convert received audio to tensor
audio_chunk = decode_audio_data(audio_data)
@@ -522,7 +552,7 @@ def handle_stream_audio(data):
client['energy_window'].clear()
client['is_silence'] = False
client['last_active_time'] = time.time()
print(f"[{client_id}] Streaming started with speaker ID: {speaker_id}")
logger.info(f"[{client_id[:8]}] Streaming started with speaker ID: {speaker_id}")
emit('streaming_status', {
'type': 'streaming_status',
'status': 'started'
@@ -553,52 +583,74 @@ def handle_stream_audio(data):
if client['is_silence'] and silence_elapsed >= SILENCE_DURATION_SEC and len(client['streaming_buffer']) > 0:
# User has stopped talking - process the collected audio
print(f"[{client_id}] Processing audio after {silence_elapsed:.2f}s of silence")
logger.info(f"[{client_id[:8]}] Processing audio after {silence_elapsed:.2f}s of silence")
process_complete_utterance(client_id, client, speaker_id)
full_audio = torch.cat(client['streaming_buffer'], dim=0)
# If buffer gets too large without silence, process it anyway
elif len(client['streaming_buffer']) >= MAX_BUFFER_SIZE:
logger.info(f"[{client_id[:8]}] Processing long audio segment without silence")
process_complete_utterance(client_id, client, speaker_id, is_incomplete=True)
# Process with WhisperX speech-to-text
print(f"[{client_id}] Starting transcription with WhisperX...")
transcribed_text = transcribe_audio(full_audio)
# Keep half of the buffer for context (sliding window approach)
half_point = len(client['streaming_buffer']) // 2
client['streaming_buffer'] = client['streaming_buffer'][half_point:]
# Log the transcription
print(f"[{client_id}] Transcribed text: '{transcribed_text}'")
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"Error processing streaming audio: {e}")
emit('error', {
'type': 'error',
'message': f"Error processing streaming audio: {str(e)}"
})
# Handle the transcription result
if transcribed_text:
# Add user message to context
user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
client['context_segments'].append(user_segment)
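# Sliding-window note (explanatory): when the buffer fills without an intervening
# silence, the accumulated chunks are transcribed as an incomplete utterance and only
# the newer half of the buffer is retained; e.g. with MAX_BUFFER_SIZE = 30, the first
# pass transcribes chunks 0-29 and chunks 15-29 are kept as context for the next pass.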
def process_complete_utterance(client_id, client, speaker_id, is_incomplete=False):
"""Process a complete utterance (after silence or buffer limit)"""
try:
# Combine audio chunks
full_audio = torch.cat(client['streaming_buffer'], dim=0)
# Process with speech-to-text
logger.info(f"[{client_id[:8]}] Starting transcription...")
transcribed_text = transcribe_audio(full_audio)
# Add suffix for incomplete utterances
if is_incomplete:
transcribed_text += " (processing continued speech...)"
# Log the transcription
logger.info(f"[{client_id[:8]}] Transcribed: '{transcribed_text}'")
# Handle the transcription result
if transcribed_text:
# Add user message to context
user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
client['context_segments'].append(user_segment)
# Send the transcribed text to client
emit('transcription', {
'type': 'transcription',
'text': transcribed_text
}, room=client_id)
# Only generate a response if this is a complete utterance
if not is_incomplete:
# Generate a contextual response
response_text = generate_response(transcribed_text, client['context_segments'])
print(f"[{client_id}] Generating audio response: '{response_text}'")
logger.info(f"[{client_id[:8]}] Generating response: '{response_text}'")
# Let the client know we're processing
emit('processing_status', {
'type': 'processing_status',
'status': 'generating_audio',
'message': 'Generating audio response...'
}, room=client_id)
# Generate audio for the response
try:
# Use a different speaker than the user
ai_speaker_id = 1 if speaker_id == 0 else 0
# Generate the full response
# (The CSM model doesn't natively support incremental generation; chunked streaming
# to the client is simulated separately in stream_audio_to_client.)
audio_tensor = generator.generate(
text=response_text,
@@ -621,60 +673,37 @@ def handle_stream_audio(data):
'type': 'audio_response',
'text': response_text,
'audio': audio_base64
}, room=client_id)
print(f"[{client_id}] Audio response sent: {len(audio_base64)} bytes")
logger.info(f"[{client_id[:8]}] Audio response sent")
except Exception as e:
logger.error(f"Error generating audio response: {e}")
emit('error', {
'type': 'error',
'message': "Sorry, there was an error generating the audio response."
}, room=client_id)
else:
# If transcription failed, send a notification
emit('error', {
'type': 'error',
'message': "Sorry, I couldn't understand what you said. Could you try again?"
}, room=client_id)
# Only clear buffer for complete utterances
if not is_incomplete:
# Reset state
client['streaming_buffer'] = []
client['energy_window'].clear()
client['is_silence'] = False
client['last_active_time'] = time.time()
except Exception as e:
logger.error(f"Error processing utterance: {e}")
emit('error', {
'type': 'error',
'message': f"Error processing audio: {str(e)}"
}, room=client_id)
@socketio.on('stop_streaming')
def handle_stop_streaming(data):
@@ -687,21 +716,8 @@ def handle_stop_streaming(data):
if client['streaming_buffer'] and len(client['streaming_buffer']) > 5:
# Process any remaining audio in the buffer
logger.info(f"[{client_id[:8]}] Processing final audio buffer on stop")
process_complete_utterance(client_id, client, data.get("speaker", 0))
client['streaming_buffer'] = []
emit('streaming_status', {
@@ -709,18 +725,18 @@ def handle_stop_streaming(data):
'status': 'stopped'
})
def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size_ms=CHUNK_SIZE_MS):
"""Stream audio to client in chunks to simulate real-time generation"""
try:
if client_id not in active_clients:
print(f"Client {client_id} not found for streaming")
logger.warning(f"Client {client_id} not found for streaming")
return
# Calculate chunk size in samples
chunk_size = int(generator.sample_rate * chunk_size_ms / 1000)
total_chunks = math.ceil(audio_tensor.size(0) / chunk_size)
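# Worked example (illustrative; assumes a 24 kHz generator.sample_rate, the actual rate
# comes from the loaded model): with chunk_size_ms = 500, each chunk holds
# int(24000 * 500 / 1000) = 12000 samples, so a 3-second response of 72000 samples is
# streamed as math.ceil(72000 / 12000) = 6 chunks.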
print(f"Streaming audio in {total_chunks} chunks of {chunk_size_ms}ms each")
logger.info(f"Streaming audio in {total_chunks} chunks of {chunk_size_ms}ms each")
# Send initial response with text but no audio yet
socketio.emit('audio_response_start', {
@@ -758,29 +774,24 @@ def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size
'text': text
}, room=client_id)
print(f"Audio streaming complete: {total_chunks} chunks sent")
logger.info(f"Audio streaming complete: {total_chunks} chunks sent")
except Exception as e:
print(f"Error streaming audio to client: {str(e)}")
logger.error(f"Error streaming audio to client: {e}")
import traceback
traceback.print_exc()
# Main server start
if __name__ == "__main__":
print(f"\n{'='*60}")
print(f"🔊 Sesame AI Voice Chat Server (Flask Implementation)")
print(f"🔊 Sesame AI Voice Chat Server")
print(f"{'='*60}")
print(f"📡 Server Information:")
print(f" - Local URL: http://localhost:5000")
print(f" - Network URL: http://<your-ip-address>:5000")
print(f" - WebSocket: ws://<your-ip-address>:5000/socket.io")
print(f"{'='*60}")
print(f"💡 To make this server public:")
print(f" 1. Ensure port 5000 is open in your firewall")
print(f" 2. Set up port forwarding on your router to port 5000")
print(f" 3. Or use a service like ngrok with: ngrok http 5000")
print(f"{'='*60}")
print(f"🌐 Device: {device.upper()}")
print(f"🧠 Models loaded: Sesame CSM + WhisperX ({asr_model.device})")
print(f"🧠 Models: Sesame CSM + WhisperX ASR")
print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}")
print(f"{'='*60}")
print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n")