Demo Update 4

2025-03-30 00:14:47 -04:00
parent eef7da454a
commit 230117a022


@@ -16,56 +16,91 @@ import gc
from collections import deque
from threading import Lock

# Add this at the top of your file, replacing your current CUDA setup
import torch
import os

# CUDA setup with robust error handling
try:
    # Handle CUDA issues
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Limit to first GPU only
    torch.backends.cudnn.benchmark = True

    # Try enabling TF32 precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    except:
        pass  # Ignore if not supported

    # Check if CUDA is available
    if torch.cuda.is_available():
        try:
            # Test CUDA functionality
            x = torch.rand(10, device="cuda")
            y = x + x
            del x, y
            device = "cuda"
            compute_type = "float16"
            print("CUDA is fully functional")
        except Exception as cuda_error:
            print(f"CUDA is available but not working correctly: {str(cuda_error)}")
            device = "cpu"
            compute_type = "int8"
    else:
        device = "cpu"
        compute_type = "int8"
except Exception as e:
    print(f"Error setting up CUDA: {str(e)}")
    device = "cpu"
    compute_type = "int8"

print(f"Using device: {device} with compute type: {compute_type}")
# Initialize the Sesame CSM model with robust error handling
try:
    print(f"Loading Sesame CSM model on {device}...")
    generator = load_csm_1b(device=device)
    print("Sesame CSM model loaded successfully")
except Exception as model_error:
    print(f"Error loading Sesame CSM on {device}: {str(model_error)}")
    if device == "cuda":
        # Try on CPU as fallback
        try:
            print("Trying to load Sesame CSM on CPU instead...")
            device = "cpu"  # Update global device setting
            generator = load_csm_1b(device="cpu")
            print("Sesame CSM model loaded on CPU successfully")
        except Exception as cpu_error:
            print(f"Fatal error - could not load Sesame CSM model: {str(cpu_error)}")
            raise RuntimeError("Failed to load speech synthesis model")
    else:
        # Already tried CPU and it failed
        raise RuntimeError("Failed to load speech synthesis model on any device")
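
The generator loaded above is consumed elsewhere in the file; for reference, a typical call against the upstream csm-1b examples looks roughly like the sketch below. The keyword arguments and generator.sample_rate follow the public CSM README and are an assumption here, not part of this diff:

import torchaudio

# Illustrative usage: synthesize a short utterance with no prior context
audio = generator.generate(
    text="Hello, how can I help you?",
    speaker=0,
    context=[],
    max_audio_length_ms=10_000,
)
torchaudio.save("reply.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)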
# Initialize WhisperX for ASR with robust error handling
print("Loading WhisperX model...")
try:
    # First try the smallest model ("tiny") to avoid memory issues
    asr_model = whisperx.load_model("tiny", device, compute_type=compute_type)
    print("WhisperX 'tiny' model loaded successfully")

    # If tiny worked and we have CUDA, try upgrading to small
    if device == "cuda":
        try:
            asr_model = whisperx.load_model("small", device, compute_type=compute_type)
            print("WhisperX 'small' model loaded successfully")
        except Exception as upgrade_error:
            print(f"Staying with 'tiny' model: {str(upgrade_error)}")
except Exception as e:
    print(f"Error loading models on {device}: {str(e)}")
    print("Falling back to CPU model")
    try:
        # Force CPU as last resort
        device = "cpu"
        compute_type = "int8"
        asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
        print("WhisperX loaded on CPU as last resort")
    except Exception as cpu_error:
        print(f"Fatal error - could not load any model: {str(cpu_error)}")
        raise RuntimeError("No ASR model could be loaded. Please check your CUDA installation.")

# Silence detection parameters
SILENCE_THRESHOLD = 0.01  # Adjust based on your audio normalization
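
The hunk does not show where SILENCE_THRESHOLD is consumed; one plausible reading, assuming audio normalized to [-1, 1], compares each chunk's RMS energy against it. The is_silence helper below is illustrative only, not code from this commit:

import torch

def is_silence(chunk: torch.Tensor, threshold: float = SILENCE_THRESHOLD) -> bool:
    """True when the chunk's RMS energy falls below the silence threshold."""
    rms = torch.sqrt(torch.mean(chunk.float() ** 2))
    return rms.item() < threshold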
@@ -226,7 +261,7 @@ def encode_audio_data(audio_tensor: torch.Tensor) -> str:
def transcribe_audio(audio_tensor: torch.Tensor) -> str:
    """Transcribe audio using WhisperX with robust error handling"""
    global asr_model  # Must precede any use of asr_model; the CPU fallback below rebinds it
    try:
        # Save the tensor to a temporary file
        temp_path = os.path.join(base_dir, "temp_audio.wav")
@@ -234,9 +269,38 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
print(f"Transcribing audio file: {temp_path} (size: {os.path.getsize(temp_path)} bytes)") print(f"Transcribing audio file: {temp_path} (size: {os.path.getsize(temp_path)} bytes)")
# Load and transcribe the audio # Load the audio file using whisperx's function
try:
audio = whisperx.load_audio(temp_path) audio = whisperx.load_audio(temp_path)
result = asr_model.transcribe(audio, batch_size=16) except Exception as audio_load_error:
print(f"WhisperX load_audio failed: {str(audio_load_error)}")
# Fall back to manual loading
import soundfile as sf
audio, sr = sf.read(temp_path)
if sr != 16000: # WhisperX expects 16kHz audio
from scipy import signal
audio = signal.resample(audio, int(len(audio) * 16000 / sr))
# Transcribe with error handling for CUDA issues
try:
# Try with original device
result = asr_model.transcribe(audio, batch_size=8)
except RuntimeError as cuda_error:
if "CUDA" in str(cuda_error) or "libcudnn" in str(cuda_error):
print(f"CUDA error in transcription, falling back to CPU: {str(cuda_error)}")
# Try to load a CPU model as fallback
try:
global asr_model
# Move model to CPU and try again
asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
result = asr_model.transcribe(audio, batch_size=1)
except Exception as e:
print(f"CPU fallback also failed: {str(e)}")
return "I'm having trouble processing audio right now."
else:
# Re-raise if it's not a CUDA error
raise
# Clean up # Clean up
if os.path.exists(temp_path): if os.path.exists(temp_path):
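
The hunk cuts off before result is consumed; assuming the standard WhisperX result shape (a dict whose "segments" list holds per-segment "text" entries), the transcript is typically assembled from it like this (illustrative, not shown in this diff):

# Join the per-segment transcripts into one string
text = " ".join(seg["text"].strip() for seg in result.get("segments", []))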
@@ -257,7 +321,7 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
        traceback.print_exc()
        if os.path.exists("temp_audio.wav"):
            os.remove("temp_audio.wav")
        return "I heard something but couldn't understand it."

def generate_response(text: str, conversation_history: List[Segment]) -> str: