Demo Update 4
@@ -16,56 +16,91 @@ import gc
 from collections import deque
 from threading import Lock
 
-# Add these lines right after your imports
-import torch
-import os
-
-# Handle CUDA issues
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Limit to first GPU only
-torch.backends.cudnn.benchmark = True
-
-# Set CUDA settings to avoid TF32 warnings
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-
-# Set compute type based on available hardware
-if torch.cuda.is_available():
-    device = "cuda"
-    compute_type = "float16"  # Faster for CUDA
-else:
-    device = "cpu"
-    compute_type = "int8"  # Better for CPU
+# Add this at the top of your file, replacing your current CUDA setup
+
+# CUDA setup with robust error handling
+try:
+    # Handle CUDA issues
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Limit to first GPU only
+
+    # Try enabling TF32 precision
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+    except:
+        pass  # Ignore if not supported
+
+    # Check if CUDA is available
+    if torch.cuda.is_available():
+        try:
+            # Test CUDA functionality
+            x = torch.rand(10, device="cuda")
+            y = x + x
+            del x, y
+            device = "cuda"
+            compute_type = "float16"
+            print("CUDA is fully functional")
+        except Exception as cuda_error:
+            print(f"CUDA is available but not working correctly: {str(cuda_error)}")
+            device = "cpu"
+            compute_type = "int8"
+    else:
+        device = "cpu"
+        compute_type = "int8"
+except Exception as e:
+    print(f"Error setting up CUDA: {str(e)}")
+    device = "cpu"
+    compute_type = "int8"
 
 print(f"Using device: {device} with compute type: {compute_type}")
 
-# Select device
-if torch.cuda.is_available():
-    device = "cuda"
-else:
-    device = "cpu"
-print(f"Using device: {device}")
-
-# Initialize the model
-generator = load_csm_1b(device=device)
-
-# Initialize WhisperX for ASR
+# Initialize the Sesame CSM model with robust error handling
+try:
+    print(f"Loading Sesame CSM model on {device}...")
+    generator = load_csm_1b(device=device)
+    print("Sesame CSM model loaded successfully")
+except Exception as model_error:
+    print(f"Error loading Sesame CSM on {device}: {str(model_error)}")
+    if device == "cuda":
+        # Try on CPU as fallback
+        try:
+            print("Trying to load Sesame CSM on CPU instead...")
+            device = "cpu"  # Update global device setting
+            generator = load_csm_1b(device="cpu")
+            print("Sesame CSM model loaded on CPU successfully")
+        except Exception as cpu_error:
+            print(f"Fatal error - could not load Sesame CSM model: {str(cpu_error)}")
+            raise RuntimeError("Failed to load speech synthesis model")
+    else:
+        # Already tried CPU and it failed
+        raise RuntimeError("Failed to load speech synthesis model on any device")
+
+# Initialize WhisperX for ASR with robust error handling
 print("Loading WhisperX model...")
 try:
-    # Try to load a smaller model for faster response times
-    asr_model = whisperx.load_model("small", device, compute_type=compute_type)
-    print("WhisperX 'small' model loaded successfully")
+    # First try the smallest model ("tiny") to avoid memory issues
+    asr_model = whisperx.load_model("tiny", device, compute_type=compute_type)
+    print("WhisperX 'tiny' model loaded successfully")
+
+    # If tiny worked and we have CUDA, try upgrading to small
+    if device == "cuda":
+        try:
+            asr_model = whisperx.load_model("small", device, compute_type=compute_type)
+            print("WhisperX 'small' model loaded successfully")
+        except Exception as upgrade_error:
+            print(f"Staying with 'tiny' model: {str(upgrade_error)}")
 except Exception as e:
-    print(f"Error loading 'small' model: {str(e)}")
+    print(f"Error loading models on {device}: {str(e)}")
+    print("Falling back to CPU model")
     try:
-        # Fall back to tiny model if small fails
-        asr_model = whisperx.load_model("tiny", device, compute_type=compute_type)
-        print("WhisperX 'tiny' model loaded as fallback")
-    except Exception as e2:
-        print(f"Error loading fallback model: {str(e2)}")
-        print("Trying CPU model as last resort")
-        # Last resort - try CPU
+        # Force CPU as last resort
+        device = "cpu"
+        compute_type = "int8"
         asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
         print("WhisperX loaded on CPU as last resort")
+    except Exception as cpu_error:
+        print(f"Fatal error - could not load any model: {str(cpu_error)}")
+        raise RuntimeError("No ASR model could be loaded. Please check your CUDA installation.")
 
 # Silence detection parameters
 SILENCE_THRESHOLD = 0.01  # Adjust based on your audio normalization
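For reference, a minimal sketch of how a threshold like SILENCE_THRESHOLD is typically applied, assuming incoming chunks are float tensors normalized to roughly [-1.0, 1.0]; the is_silent helper is illustrative and not part of this commit:

# Illustrative only (not in the diff): RMS-based silence check on a normalized chunk.
def is_silent(chunk: torch.Tensor, threshold: float = SILENCE_THRESHOLD) -> bool:
    rms = torch.sqrt(torch.mean(chunk.float() ** 2))
    return rms.item() < threshold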
@@ -226,7 +261,7 @@ def encode_audio_data(audio_tensor: torch.Tensor) -> str:
 
 
 def transcribe_audio(audio_tensor: torch.Tensor) -> str:
-    """Transcribe audio using WhisperX"""
+    """Transcribe audio using WhisperX with robust error handling"""
     try:
         # Save the tensor to a temporary file
         temp_path = os.path.join(base_dir, "temp_audio.wav")
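The step that writes audio_tensor to temp_path sits in the unchanged lines between these hunks and is not shown here. A hedged sketch of one common way to do it, assuming torchaudio is available and audio_tensor is a 1-D float tensor already at 16 kHz:

# Sketch only; the real save code lives in the unchanged lines between hunks.
import torchaudio
torchaudio.save(temp_path, audio_tensor.unsqueeze(0), 16000)  # expects (channels, frames)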
@@ -234,9 +269,38 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
 
         print(f"Transcribing audio file: {temp_path} (size: {os.path.getsize(temp_path)} bytes)")
 
-        # Load and transcribe the audio
-        audio = whisperx.load_audio(temp_path)
-        result = asr_model.transcribe(audio, batch_size=16)
+        # Load the audio file using whisperx's function
+        try:
+            audio = whisperx.load_audio(temp_path)
+        except Exception as audio_load_error:
+            print(f"WhisperX load_audio failed: {str(audio_load_error)}")
+            # Fall back to manual loading
+            import soundfile as sf
+            audio, sr = sf.read(temp_path)
+            if sr != 16000:  # WhisperX expects 16kHz audio
+                from scipy import signal
+                audio = signal.resample(audio, int(len(audio) * 16000 / sr))
+
+        # Transcribe with error handling for CUDA issues
+        global asr_model  # declared before first use so the CPU fallback below can rebind it
+        try:
+            # Try with original device
+            result = asr_model.transcribe(audio, batch_size=8)
+        except RuntimeError as cuda_error:
+            if "CUDA" in str(cuda_error) or "libcudnn" in str(cuda_error):
+                print(f"CUDA error in transcription, falling back to CPU: {str(cuda_error)}")
+
+                # Try to load a CPU model as fallback
+                try:
+                    # Move model to CPU and try again
+                    asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
+                    result = asr_model.transcribe(audio, batch_size=1)
+                except Exception as e:
+                    print(f"CPU fallback also failed: {str(e)}")
+                    return "I'm having trouble processing audio right now."
+            else:
+                # Re-raise if it's not a CUDA error
+                raise
 
         # Clean up
         if os.path.exists(temp_path):
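One quick way to exercise the new fallback paths is to call transcribe_audio with a synthetic tensor. A hypothetical smoke test, assuming the function accepts a 1-D float tensor sampled at 16 kHz (the rate the code above resamples to):

# Hypothetical smoke test (not part of the commit): one second of a 440 Hz tone.
import torch
tone = 0.1 * torch.sin(2 * torch.pi * 440 * torch.linspace(0, 1, 16000))
print(transcribe_audio(tone))  # should return text or a fallback phrase, not raise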
@@ -257,7 +321,7 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
         traceback.print_exc()
         if os.path.exists("temp_audio.wav"):
             os.remove("temp_audio.wav")
-        return ""
+        return "I heard something but couldn't understand it."
 
 
 def generate_response(text: str, conversation_history: List[Segment]) -> str:
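Because the error paths now return user-facing phrases instead of empty strings, callers that previously checked for an empty transcript may want to filter these sentinels before generating a reply. A sketch under that assumption; the calling code is not part of this diff:

# The phrases mirror the strings introduced above; adjust if they change.
FALLBACK_PHRASES = {
    "I'm having trouble processing audio right now.",
    "I heard something but couldn't understand it.",
}

def should_respond(transcript: str) -> bool:
    """Skip response generation when transcription produced only a fallback phrase."""
    return bool(transcript) and transcript not in FALLBACK_PHRASES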