Demo Update 4
@@ -16,56 +16,91 @@ import gc
 from collections import deque
 from threading import Lock
 
 # Add these lines right after your imports
 import torch
 import os
 
-# Handle CUDA issues
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Limit to first GPU only
-torch.backends.cudnn.benchmark = True
-
-# Set CUDA settings to avoid TF32 warnings
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-
-# Set compute type based on available hardware
-if torch.cuda.is_available():
-    compute_type = "float16"
-else:
-    compute_type = "int8"
+# Add this at the top of your file, replacing your current CUDA setup
+
+# CUDA setup with robust error handling
+try:
+    # Handle CUDA issues
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Limit to first GPU only
+
+    # Try enabling TF32 precision
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+    except:
+        pass  # Ignore if not supported
+
+    # Check if CUDA is available
+    if torch.cuda.is_available():
+        try:
+            # Test CUDA functionality
+            x = torch.rand(10, device="cuda")
+            y = x + x
+            del x, y
+            device = "cuda"
+            compute_type = "float16"  # Faster for CUDA
+            print("CUDA is fully functional")
+        except Exception as cuda_error:
+            print(f"CUDA is available but not working correctly: {str(cuda_error)}")
+            device = "cpu"
+            compute_type = "int8"  # Better for CPU
+    else:
+        device = "cpu"
+        compute_type = "int8"
+except Exception as e:
+    print(f"Error setting up CUDA: {str(e)}")
+    device = "cpu"
+    compute_type = "int8"
+
+print(f"Using device: {device} with compute type: {compute_type}")
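For reference, the device and compute-type probe added above can be factored into a small helper. The sketch below is built only from the torch calls already used in this commit (torch.cuda.is_available plus a throwaway tensor allocation); the helper name is hypothetical and not part of the repository:

import torch

def pick_device_and_compute_type():
    """Probe CUDA with a tiny allocation; fall back to CPU on any failure.

    Returns a (device, compute_type) pair such as ("cuda", "float16") or ("cpu", "int8").
    """
    if torch.cuda.is_available():
        try:
            x = torch.rand(10, device="cuda")  # minimal allocation to confirm the driver works
            _ = x + x
            del x
            return "cuda", "float16"  # float16 is the faster compute type on GPU
        except Exception as err:
            print(f"CUDA present but unusable, using CPU: {err}")
    return "cpu", "int8"  # int8 is the safer compute type for CPU inference

device, compute_type = pick_device_and_compute_type()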
-# Select device
-if torch.cuda.is_available():
-    device = "cuda"
-else:
-    device = "cpu"
-print(f"Using device: {device}")
-
-# Initialize the model
-generator = load_csm_1b(device=device)
+# Initialize the Sesame CSM model with robust error handling
+try:
+    print(f"Loading Sesame CSM model on {device}...")
+    generator = load_csm_1b(device=device)
+    print("Sesame CSM model loaded successfully")
+except Exception as model_error:
+    print(f"Error loading Sesame CSM on {device}: {str(model_error)}")
+    if device == "cuda":
+        # Try on CPU as fallback
+        try:
+            print("Trying to load Sesame CSM on CPU instead...")
+            device = "cpu"  # Update global device setting
+            generator = load_csm_1b(device="cpu")
+            print("Sesame CSM model loaded on CPU successfully")
+        except Exception as cpu_error:
+            print(f"Fatal error - could not load Sesame CSM model: {str(cpu_error)}")
+            raise RuntimeError("Failed to load speech synthesis model")
+    else:
+        # Already tried CPU and it failed
+        raise RuntimeError("Failed to load speech synthesis model on any device")
 
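The Sesame CSM block above follows a load-on-GPU, retry-on-CPU pattern. As a sketch of that pattern in isolation, assuming only that load_csm_1b(device=...) behaves as it is called in the diff, the fallback can be written once as a generic helper (helper name hypothetical):

from typing import Callable, TypeVar

T = TypeVar("T")

def load_with_cpu_fallback(loader: Callable[[str], T], device: str) -> T:
    """Try the requested device first; retry on CPU before giving up."""
    try:
        return loader(device)
    except Exception as err:
        if device == "cpu":
            raise RuntimeError("Model failed to load on CPU") from err
        print(f"Load failed on {device}, retrying on CPU: {err}")
        return loader("cpu")

# Hypothetical usage mirroring the block above:
# generator = load_with_cpu_fallback(lambda d: load_csm_1b(device=d), device)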
-# Initialize WhisperX for ASR
+# Initialize WhisperX for ASR with robust error handling
 print("Loading WhisperX model...")
 try:
-    # Try to load a smaller model for faster response times
+    # First try the smallest model ("tiny") to avoid memory issues
+    asr_model = whisperx.load_model("tiny", device, compute_type=compute_type)
+    print("WhisperX 'tiny' model loaded successfully")
+
+    # If tiny worked and we have CUDA, try upgrading to small
+    if device == "cuda":
+        try:
             asr_model = whisperx.load_model("small", device, compute_type=compute_type)
             print("WhisperX 'small' model loaded successfully")
+        except Exception as upgrade_error:
+            print(f"Staying with 'tiny' model: {str(upgrade_error)}")
 except Exception as e:
-    print(f"Error loading 'small' model: {str(e)}")
+    print(f"Error loading models on {device}: {str(e)}")
+    print("Falling back to CPU model")
     try:
-        # Fall back to tiny model if small fails
-        asr_model = whisperx.load_model("tiny", device, compute_type=compute_type)
-        print("WhisperX 'tiny' model loaded as fallback")
-    except Exception as e2:
-        print(f"Error loading fallback model: {str(e2)}")
-        print("Trying CPU model as last resort")
-        # Last resort - try CPU
+        # Force CPU as last resort
         device = "cpu"
         compute_type = "int8"
         asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
         print("WhisperX loaded on CPU as last resort")
+    except Exception as cpu_error:
+        print(f"Fatal error - could not load any model: {str(cpu_error)}")
+        raise RuntimeError("No ASR model could be loaded. Please check your CUDA installation.")
 
 # Silence detection parameters
 SILENCE_THRESHOLD = 0.01  # Adjust based on your audio normalization
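SILENCE_THRESHOLD is a bare constant in this hunk, so as a worked illustration of how such a threshold is commonly applied, the sketch below gates a normalized mono chunk on its RMS energy. It assumes samples in the [-1, 1] range and is not the detection code from this file:

import torch

SILENCE_THRESHOLD = 0.01  # same value as above; assumes audio normalized to [-1, 1]

def is_silent(chunk: torch.Tensor, threshold: float = SILENCE_THRESHOLD) -> bool:
    """Return True when the chunk's RMS energy falls below the threshold."""
    rms = chunk.float().pow(2).mean().sqrt().item()
    return rms < threshold

# Example: a 20 ms chunk of zeros at 16 kHz counts as silence.
print(is_silent(torch.zeros(320)))  # True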
@@ -226,7 +261,7 @@ def encode_audio_data(audio_tensor: torch.Tensor) -> str:
 
 
 def transcribe_audio(audio_tensor: torch.Tensor) -> str:
-    """Transcribe audio using WhisperX"""
+    """Transcribe audio using WhisperX with robust error handling"""
     try:
         # Save the tensor to a temporary file
         temp_path = os.path.join(base_dir, "temp_audio.wav")
@@ -234,9 +269,38 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
 
         print(f"Transcribing audio file: {temp_path} (size: {os.path.getsize(temp_path)} bytes)")
 
-        # Load and transcribe the audio
-        audio = whisperx.load_audio(temp_path)
-        result = asr_model.transcribe(audio, batch_size=16)
+        # Load the audio file using whisperx's function
+        try:
+            audio = whisperx.load_audio(temp_path)
+        except Exception as audio_load_error:
+            print(f"WhisperX load_audio failed: {str(audio_load_error)}")
+            # Fall back to manual loading
+            import soundfile as sf
+            audio, sr = sf.read(temp_path)
+            if sr != 16000:  # WhisperX expects 16kHz audio
+                from scipy import signal
+                audio = signal.resample(audio, int(len(audio) * 16000 / sr))
+
+        # Transcribe with error handling for CUDA issues
+        try:
+            # Try with original device
+            result = asr_model.transcribe(audio, batch_size=8)
+        except RuntimeError as cuda_error:
+            if "CUDA" in str(cuda_error) or "libcudnn" in str(cuda_error):
+                print(f"CUDA error in transcription, falling back to CPU: {str(cuda_error)}")
+
+                # Try to load a CPU model as fallback
+                try:
+                    global asr_model
+                    # Move model to CPU and try again
+                    asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
+                    result = asr_model.transcribe(audio, batch_size=1)
+                except Exception as e:
+                    print(f"CPU fallback also failed: {str(e)}")
+                    return "I'm having trouble processing audio right now."
+            else:
+                # Re-raise if it's not a CUDA error
+                raise
 
         # Clean up
         if os.path.exists(temp_path):
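One caveat in the fallback added above: `global asr_model` is declared after `asr_model` has already been read earlier in the same function, and Python rejects that with "name 'asr_model' is used prior to global declaration" at import time. The sketch below shows the same CPU retry with the declaration hoisted before any use, plus an assumed way to pull text out of the whisperx result (the segments/text shape is an assumption, and the helper relies on the module-level whisperx import and asr_model defined above):

def transcribe_with_cpu_fallback(audio) -> str:
    """Sketch of the fallback path with the global declaration placed before any use."""
    global asr_model  # must precede every read of asr_model in this function
    try:
        result = asr_model.transcribe(audio, batch_size=8)
    except RuntimeError as cuda_error:
        if "CUDA" not in str(cuda_error) and "libcudnn" not in str(cuda_error):
            raise  # only CUDA/cuDNN failures trigger the CPU retry
        asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
        result = asr_model.transcribe(audio, batch_size=1)
    # Join segment texts; assumes the usual whisperx result dict with a "segments" list
    return " ".join(seg["text"].strip() for seg in result.get("segments", []))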
@@ -257,7 +321,7 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
         traceback.print_exc()
         if os.path.exists("temp_audio.wav"):
             os.remove("temp_audio.wav")
-        return ""
+        return "I heard something but couldn't understand it."
 
 
 def generate_response(text: str, conversation_history: List[Segment]) -> str: