Demo Update 7

2025-03-30 00:33:14 -04:00
parent 6152e300c0
commit 8592257cdc

@@ -8,7 +8,6 @@ import logging
 import numpy as np
 import torch
 import torchaudio
-import whisperx
 from io import BytesIO
 from typing import List, Dict, Any, Optional
 from flask import Flask, request, send_from_directory, Response
@@ -25,68 +24,24 @@ logging.basicConfig(
 )
 logger = logging.getLogger("sesame-server")
 
-# CUDA Environment Setup
-def setup_cuda_environment():
-    """Set up CUDA environment with proper error handling"""
-    # Search for CUDA libraries in common locations
-    cuda_lib_dirs = [
-        "/usr/local/cuda/lib64",
-        "/usr/lib/x86_64-linux-gnu",
-        "/usr/local/cuda/extras/CUPTI/lib64"
-    ]
-
-    # Add directories to LD_LIBRARY_PATH if they exist
-    current_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
-    for cuda_dir in cuda_lib_dirs:
-        if os.path.exists(cuda_dir) and cuda_dir not in current_ld_path:
-            if current_ld_path:
-                os.environ['LD_LIBRARY_PATH'] = f"{current_ld_path}:{cuda_dir}"
-            else:
-                os.environ['LD_LIBRARY_PATH'] = cuda_dir
-            current_ld_path = os.environ['LD_LIBRARY_PATH']
-    logger.info(f"LD_LIBRARY_PATH set to: {os.environ.get('LD_LIBRARY_PATH', 'not set')}")
-
-    # Determine best compute device
-    device = "cpu"
-    compute_type = "int8"
-    try:
-        # Set CUDA preferences
-        os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Limit to first GPU only
-        # Try enabling TF32 precision if available
-        try:
-            torch.backends.cuda.matmul.allow_tf32 = True
-            torch.backends.cudnn.allow_tf32 = True
-            torch.backends.cudnn.enabled = True
-            torch.backends.cudnn.benchmark = True
-        except Exception as e:
-            logger.warning(f"Could not set advanced CUDA options: {e}")
-
-        # Test if CUDA is functional
-        if torch.cuda.is_available():
-            try:
-                # Test basic CUDA operations
-                x = torch.rand(10, device="cuda")
-                y = x + x
-                del x, y
-                torch.cuda.empty_cache()
-                device = "cuda"
-                compute_type = "float16"
-                logger.info("CUDA is fully functional")
-            except Exception as e:
-                logger.warning(f"CUDA available but not working correctly: {e}")
-                device = "cpu"
-        else:
-            logger.info("CUDA is not available, using CPU")
-    except Exception as e:
-        logger.error(f"Error setting up computing environment: {e}")
-    return device, compute_type
-
-# Set up the compute environment
-device, compute_type = setup_cuda_environment()
+# Determine best compute device
+if torch.backends.mps.is_available():
+    device = "mps"
+elif torch.cuda.is_available():
+    try:
+        # Test CUDA functionality
+        torch.rand(10, device="cuda")
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+        device = "cuda"
+        logger.info("CUDA is fully functional")
+    except Exception as e:
+        logger.warning(f"CUDA available but not working correctly: {e}")
+        device = "cpu"
+else:
+    device = "cpu"
+    logger.info("Using CPU")
 
 # Constants and Configuration
 SILENCE_THRESHOLD = 0.01
@@ -99,9 +54,37 @@ base_dir = os.path.dirname(os.path.abspath(__file__))
 static_dir = os.path.join(base_dir, "static")
 os.makedirs(static_dir, exist_ok=True)
 
+# Define a simple energy-based speech detector
+class SpeechDetector:
+    def __init__(self):
+        self.min_speech_energy = 0.01
+        self.speech_window = 0.2  # seconds
+
+    def detect_speech(self, audio_tensor, sample_rate):
+        # Calculate frame size based on window size
+        frame_size = int(sample_rate * self.speech_window)
+
+        # If audio is shorter than frame size, use the entire audio
+        if audio_tensor.shape[0] < frame_size:
+            frames = [audio_tensor]
+        else:
+            # Split audio into frames
+            frames = [audio_tensor[i:i+frame_size] for i in range(0, len(audio_tensor), frame_size)]
+
+        # Calculate energy per frame
+        energies = [torch.mean(frame**2).item() for frame in frames]
+
+        # Determine if there's speech based on energy threshold
+        has_speech = any(e > self.min_speech_energy for e in energies)
+        return has_speech
+
+speech_detector = SpeechDetector()
+logger.info("Initialized simple speech detector")
+
 # Model Loading Functions
 def load_speech_models():
-    """Load all required speech models with fallbacks"""
+    """Load speech generation model"""
     # Load speech generation model (Sesame CSM)
     try:
         logger.info(f"Loading Sesame CSM model on {device}...")
@@ -120,52 +103,10 @@ def load_speech_models():
     else:
         raise RuntimeError("Failed to load speech synthesis model on any device")
 
-    # Load ASR model (WhisperX)
-    try:
-        logger.info("Loading WhisperX model...")
-        # Start with the tiny model on CPU for reliable initialization
-        asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
-        logger.info("WhisperX 'tiny' model loaded on CPU successfully")
-
-        # Try upgrading to GPU if available
-        if device == "cuda":
-            try:
-                logger.info("Trying to load WhisperX on CUDA...")
-                # Test with a tiny model first
-                test_audio = torch.zeros(16000)  # 1 second of silence
-                cuda_model = whisperx.load_model("tiny", "cuda", compute_type="float16")
-                # Test the model with real inference
-                _ = cuda_model.transcribe(test_audio.numpy(), batch_size=1)
-                asr_model = cuda_model
-                logger.info("WhisperX model running on CUDA successfully")
-
-                # Try to upgrade to small model
-                try:
-                    small_model = whisperx.load_model("small", "cuda", compute_type="float16")
-                    _ = small_model.transcribe(test_audio.numpy(), batch_size=1)
-                    asr_model = small_model
-                    logger.info("WhisperX 'small' model loaded on CUDA successfully")
-                except Exception as e:
-                    logger.warning(f"Staying with 'tiny' model on CUDA: {e}")
-            except Exception as e:
-                logger.warning(f"CUDA loading failed, staying with CPU model: {e}")
-    except Exception as e:
-        logger.error(f"Error loading WhisperX model: {e}")
-        # Create a minimal dummy model as last resort
-        class DummyModel:
-            def __init__(self):
-                self.device = "cpu"
-            def transcribe(self, *args, **kwargs):
-                return {"segments": [{"text": "Speech recognition currently unavailable."}]}
-        asr_model = DummyModel()
-        logger.warning("Using dummy transcription model - ASR functionality limited")
-
-    return generator, asr_model
+    return generator
 
-# Load speech models
-generator, asr_model = load_speech_models()
+# Load speech model
+generator = load_speech_models()
 
 # Set up Flask and Socket.IO
 app = Flask(__name__)
@@ -307,63 +248,23 @@ def encode_audio_data(audio_tensor: torch.Tensor) -> str:
     buf.seek(0)
     return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
 
-def transcribe_audio(audio_tensor: torch.Tensor) -> str:
-    """Transcribe audio using WhisperX with robust error handling"""
-    global asr_model
-
-    try:
-        # Save the tensor to a temporary file
-        temp_path = os.path.join(base_dir, f"temp_audio_{time.time()}.wav")
-        torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
-        logger.info(f"Transcribing audio file: {os.path.getsize(temp_path)} bytes")
-
-        # Load the audio for WhisperX
-        try:
-            audio = whisperx.load_audio(temp_path)
-        except Exception as e:
-            logger.warning(f"WhisperX load_audio failed: {e}")
-            # Fall back to manual loading
-            import soundfile as sf
-            audio, sr = sf.read(temp_path)
-            if sr != 16000:  # WhisperX expects 16kHz audio
-                from scipy import signal
-                audio = signal.resample(audio, int(len(audio) * 16000 / sr))
-
-        # Transcribe with error handling
-        try:
-            result = asr_model.transcribe(audio, batch_size=4)
-        except RuntimeError as e:
-            if "CUDA" in str(e) or "libcudnn" in str(e):
-                logger.warning(f"CUDA error in transcription, falling back to CPU: {e}")
-                try:
-                    # Try CPU model
-                    cpu_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
-                    result = cpu_model.transcribe(audio, batch_size=1)
-                    # Update the global model if the original one is broken
-                    asr_model = cpu_model
-                except Exception as cpu_e:
-                    logger.error(f"CPU fallback failed: {cpu_e}")
-                    return "I'm having trouble processing audio right now."
-            else:
-                raise
-        finally:
-            # Clean up
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-
-        # Extract text from segments
-        if result["segments"] and len(result["segments"]) > 0:
-            transcription = " ".join([segment["text"] for segment in result["segments"]])
-            logger.info(f"Transcription: '{transcription.strip()}'")
-            return transcription.strip()
-        return ""
-    except Exception as e:
-        logger.error(f"Error in transcription: {e}")
-        if os.path.exists(temp_path):
-            os.remove(temp_path)
-        return "I heard something but couldn't understand it."
+def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str:
+    """Process speech and return a simple response"""
+    # In this simplified version, we'll just check if there's sound
+    # and provide basic responses instead of doing actual speech recognition
+    if speech_detector and speech_detector.detect_speech(audio_tensor, generator.sample_rate):
+        # Generate a response based on audio energy
+        energy = torch.mean(torch.abs(audio_tensor)).item()
+
+        if energy > 0.1:  # Louder speech
+            return "I heard you speaking clearly. How can I help you today?"
+        elif energy > 0.05:  # Moderate speech
+            return "I heard you say something. Could you please repeat that?"
+        else:  # Soft speech
+            return "I detected some speech, but it was quite soft. Could you speak up a bit?"
+    else:
+        return "I didn't detect any speech. Could you please try again?"
 
 def generate_response(text: str, conversation_history: List[Segment]) -> str:
     """Generate a contextual response based on the transcribed text"""
@@ -394,7 +295,7 @@ def generate_response(text: str, conversation_history: List[Segment]) -> str:
     elif len(text) < 10:
         return "Thanks for your message. Could you elaborate a bit more?"
     else:
-        return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"
+        return f"I heard you speaking. That's interesting! Can you tell me more about that?"
# Flask Routes # Flask Routes
@app.route('/') @app.route('/')
@@ -610,33 +511,32 @@ def process_complete_utterance(client_id, client, speaker_id, is_incomplete=Fals
         # Combine audio chunks
         full_audio = torch.cat(client['streaming_buffer'], dim=0)
 
-        # Process with speech-to-text
-        logger.info(f"[{client_id[:8]}] Starting transcription...")
-        transcribed_text = transcribe_audio(full_audio)
+        # Process audio to generate a response (no speech recognition)
+        generated_text = process_speech(full_audio, client_id)
 
         # Add suffix for incomplete utterances
         if is_incomplete:
-            transcribed_text += " (processing continued speech...)"
+            generated_text += " (processing continued speech...)"
 
-        # Log the transcription
-        logger.info(f"[{client_id[:8]}] Transcribed: '{transcribed_text}'")
+        # Log the generated text
+        logger.info(f"[{client_id[:8]}] Generated text: '{generated_text}'")
 
-        # Handle the transcription result
-        if transcribed_text:
+        # Handle the result
+        if generated_text:
             # Add user message to context
-            user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
+            user_segment = Segment(text=generated_text, speaker=speaker_id, audio=full_audio)
             client['context_segments'].append(user_segment)
 
-            # Send the transcribed text to client
+            # Send the text to client
             emit('transcription', {
                 'type': 'transcription',
-                'text': transcribed_text
+                'text': generated_text
             }, room=client_id)
 
             # Only generate a response if this is a complete utterance
             if not is_incomplete:
                 # Generate a contextual response
-                response_text = generate_response(transcribed_text, client['context_segments'])
+                response_text = generate_response(generated_text, client['context_segments'])
                 logger.info(f"[{client_id[:8]}] Generating response: '{response_text}'")
 
                 # Let the client know we're processing
@@ -684,7 +584,7 @@ def process_complete_utterance(client_id, client, speaker_id, is_incomplete=Fals
'message': "Sorry, there was an error generating the audio response." 'message': "Sorry, there was an error generating the audio response."
}, room=client_id) }, room=client_id)
else: else:
# If transcription failed, send a notification # If processing failed, send a notification
emit('error', { emit('error', {
'type': 'error', 'type': 'error',
'message': "Sorry, I couldn't understand what you said. Could you try again?" 'message': "Sorry, I couldn't understand what you said. Could you try again?"
@@ -791,7 +691,7 @@ if __name__ == "__main__":
print(f" - Network URL: http://<your-ip-address>:5000") print(f" - Network URL: http://<your-ip-address>:5000")
print(f"{'='*60}") print(f"{'='*60}")
print(f"🌐 Device: {device.upper()}") print(f"🌐 Device: {device.upper()}")
print(f"🧠 Models: Sesame CSM + WhisperX ASR") print(f"🧠 Models: Sesame CSM (TTS only)")
print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}") print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}")
print(f"{'='*60}") print(f"{'='*60}")
print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n") print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n")