Demo Update 7
@@ -8,7 +8,6 @@ import logging
 import numpy as np
 import torch
 import torchaudio
-import whisperx
 from io import BytesIO
 from typing import List, Dict, Any, Optional
 from flask import Flask, request, send_from_directory, Response
@@ -25,68 +24,24 @@ logging.basicConfig(
 )
 logger = logging.getLogger("sesame-server")
 
-# CUDA Environment Setup
-def setup_cuda_environment():
-    """Set up CUDA environment with proper error handling"""
-    # Search for CUDA libraries in common locations
-    cuda_lib_dirs = [
-        "/usr/local/cuda/lib64",
-        "/usr/lib/x86_64-linux-gnu",
-        "/usr/local/cuda/extras/CUPTI/lib64"
-    ]
-
-    # Add directories to LD_LIBRARY_PATH if they exist
-    current_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
-    for cuda_dir in cuda_lib_dirs:
-        if os.path.exists(cuda_dir) and cuda_dir not in current_ld_path:
-            if current_ld_path:
-                os.environ['LD_LIBRARY_PATH'] = f"{current_ld_path}:{cuda_dir}"
-            else:
-                os.environ['LD_LIBRARY_PATH'] = cuda_dir
-            current_ld_path = os.environ['LD_LIBRARY_PATH']
-
-    logger.info(f"LD_LIBRARY_PATH set to: {os.environ.get('LD_LIBRARY_PATH', 'not set')}")
-
-    # Determine best compute device
-    device = "cpu"
-    compute_type = "int8"
-
-    try:
-        # Set CUDA preferences
-        os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Limit to first GPU only
-
-        # Try enabling TF32 precision if available
-        try:
-            torch.backends.cuda.matmul.allow_tf32 = True
-            torch.backends.cudnn.allow_tf32 = True
-            torch.backends.cudnn.enabled = True
-            torch.backends.cudnn.benchmark = True
-        except Exception as e:
-            logger.warning(f"Could not set advanced CUDA options: {e}")
-
-        # Test if CUDA is functional
-        if torch.cuda.is_available():
-            try:
-                # Test basic CUDA operations
-                x = torch.rand(10, device="cuda")
-                y = x + x
-                del x, y
-                torch.cuda.empty_cache()
-                device = "cuda"
-                compute_type = "float16"
-                logger.info("CUDA is fully functional")
-            except Exception as e:
-                logger.warning(f"CUDA available but not working correctly: {e}")
-                device = "cpu"
-        else:
-            logger.info("CUDA is not available, using CPU")
-    except Exception as e:
-        logger.error(f"Error setting up computing environment: {e}")
-
-    return device, compute_type
-
-# Set up the compute environment
-device, compute_type = setup_cuda_environment()
+# Determine best compute device
+if torch.backends.mps.is_available():
+    device = "mps"
+elif torch.cuda.is_available():
+    try:
+        # Test CUDA functionality
+        torch.rand(10, device="cuda")
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+        device = "cuda"
+        logger.info("CUDA is fully functional")
+    except Exception as e:
+        logger.warning(f"CUDA available but not working correctly: {e}")
+        device = "cpu"
+else:
+    device = "cpu"
+    logger.info("Using CPU")
 
 # Constants and Configuration
 SILENCE_THRESHOLD = 0.01
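The replacement block drops the LD_LIBRARY_PATH probing and the `setup_cuda_environment()` helper: device selection now happens once at module level, preferring MPS, then CUDA (smoke-tested with a small allocation plus TF32/cuDNN flags), then CPU. A minimal standalone sketch of that same selection order, assuming only a stock PyTorch install (the helper name `pick_device` is illustrative and not part of the commit):

import torch

def pick_device() -> str:
    # Same preference order as the new module-level code: MPS, then CUDA, else CPU.
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        try:
            torch.rand(10, device="cuda")  # smoke-test an allocation before trusting CUDA
            return "cuda"
        except Exception:
            pass
    return "cpu"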
@@ -99,9 +54,37 @@ base_dir = os.path.dirname(os.path.abspath(__file__))
 static_dir = os.path.join(base_dir, "static")
 os.makedirs(static_dir, exist_ok=True)
 
+# Define a simple energy-based speech detector
+class SpeechDetector:
+    def __init__(self):
+        self.min_speech_energy = 0.01
+        self.speech_window = 0.2 # seconds
+
+    def detect_speech(self, audio_tensor, sample_rate):
+        # Calculate frame size based on window size
+        frame_size = int(sample_rate * self.speech_window)
+
+        # If audio is shorter than frame size, use the entire audio
+        if audio_tensor.shape[0] < frame_size:
+            frames = [audio_tensor]
+        else:
+            # Split audio into frames
+            frames = [audio_tensor[i:i+frame_size] for i in range(0, len(audio_tensor), frame_size)]
+
+        # Calculate energy per frame
+        energies = [torch.mean(frame**2).item() for frame in frames]
+
+        # Determine if there's speech based on energy threshold
+        has_speech = any(e > self.min_speech_energy for e in energies)
+
+        return has_speech
+
+speech_detector = SpeechDetector()
+logger.info("Initialized simple speech detector")
+
 # Model Loading Functions
 def load_speech_models():
-    """Load all required speech models with fallbacks"""
+    """Load speech generation model"""
     # Load speech generation model (Sesame CSM)
     try:
         logger.info(f"Loading Sesame CSM model on {device}...")
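For a quick sanity check of the detector added above, the following sketch runs it on one second of silence and one second of a 440 Hz tone; the 24 kHz sample rate is only an assumed stand-in for `generator.sample_rate`:

import math
import torch

detector = SpeechDetector()  # class defined in the hunk above

silence = torch.zeros(24000)                                          # 1 s of silence, energy 0.0
tone = 0.2 * torch.sin(torch.linspace(0, 2 * math.pi * 440, 24000))   # 1 s test tone

print(detector.detect_speech(silence, 24000))  # False: 0.0 never exceeds min_speech_energy (0.01)
print(detector.detect_speech(tone, 24000))     # True: mean squared amplitude ~0.02 > 0.01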
@@ -120,52 +103,10 @@ def load_speech_models():
     else:
         raise RuntimeError("Failed to load speech synthesis model on any device")
 
-    # Load ASR model (WhisperX)
-    try:
-        logger.info("Loading WhisperX model...")
-        # Start with the tiny model on CPU for reliable initialization
-        asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
-        logger.info("WhisperX 'tiny' model loaded on CPU successfully")
-
-        # Try upgrading to GPU if available
-        if device == "cuda":
-            try:
-                logger.info("Trying to load WhisperX on CUDA...")
-                # Test with a tiny model first
-                test_audio = torch.zeros(16000) # 1 second of silence
-
-                cuda_model = whisperx.load_model("tiny", "cuda", compute_type="float16")
-                # Test the model with real inference
-                _ = cuda_model.transcribe(test_audio.numpy(), batch_size=1)
-                asr_model = cuda_model
-                logger.info("WhisperX model running on CUDA successfully")
-
-                # Try to upgrade to small model
-                try:
-                    small_model = whisperx.load_model("small", "cuda", compute_type="float16")
-                    _ = small_model.transcribe(test_audio.numpy(), batch_size=1)
-                    asr_model = small_model
-                    logger.info("WhisperX 'small' model loaded on CUDA successfully")
-                except Exception as e:
-                    logger.warning(f"Staying with 'tiny' model on CUDA: {e}")
-            except Exception as e:
-                logger.warning(f"CUDA loading failed, staying with CPU model: {e}")
-    except Exception as e:
-        logger.error(f"Error loading WhisperX model: {e}")
-        # Create a minimal dummy model as last resort
-        class DummyModel:
-            def __init__(self):
-                self.device = "cpu"
-            def transcribe(self, *args, **kwargs):
-                return {"segments": [{"text": "Speech recognition currently unavailable."}]}
-
-        asr_model = DummyModel()
-        logger.warning("Using dummy transcription model - ASR functionality limited")
-
-    return generator, asr_model
-
-# Load speech models
-generator, asr_model = load_speech_models()
+    return generator
+
+# Load speech model
+generator = load_speech_models()
 
 # Set up Flask and Socket.IO
 app = Flask(__name__)
@@ -307,63 +248,23 @@ def encode_audio_data(audio_tensor: torch.Tensor) -> str:
     buf.seek(0)
     return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
 
-def transcribe_audio(audio_tensor: torch.Tensor) -> str:
-    """Transcribe audio using WhisperX with robust error handling"""
-    global asr_model
-
-    try:
-        # Save the tensor to a temporary file
-        temp_path = os.path.join(base_dir, f"temp_audio_{time.time()}.wav")
-        torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
-
-        logger.info(f"Transcribing audio file: {os.path.getsize(temp_path)} bytes")
-
-        # Load the audio for WhisperX
-        try:
-            audio = whisperx.load_audio(temp_path)
-        except Exception as e:
-            logger.warning(f"WhisperX load_audio failed: {e}")
-            # Fall back to manual loading
-            import soundfile as sf
-            audio, sr = sf.read(temp_path)
-            if sr != 16000: # WhisperX expects 16kHz audio
-                from scipy import signal
-                audio = signal.resample(audio, int(len(audio) * 16000 / sr))
-
-        # Transcribe with error handling
-        try:
-            result = asr_model.transcribe(audio, batch_size=4)
-        except RuntimeError as e:
-            if "CUDA" in str(e) or "libcudnn" in str(e):
-                logger.warning(f"CUDA error in transcription, falling back to CPU: {e}")
-                try:
-                    # Try CPU model
-                    cpu_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
-                    result = cpu_model.transcribe(audio, batch_size=1)
-                    # Update the global model if the original one is broken
-                    asr_model = cpu_model
-                except Exception as cpu_e:
-                    logger.error(f"CPU fallback failed: {cpu_e}")
-                    return "I'm having trouble processing audio right now."
-            else:
-                raise
-        finally:
-            # Clean up
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-
-        # Extract text from segments
-        if result["segments"] and len(result["segments"]) > 0:
-            transcription = " ".join([segment["text"] for segment in result["segments"]])
-            logger.info(f"Transcription: '{transcription.strip()}'")
-            return transcription.strip()
-
-        return ""
-    except Exception as e:
-        logger.error(f"Error in transcription: {e}")
-        if os.path.exists(temp_path):
-            os.remove(temp_path)
-        return "I heard something but couldn't understand it."
+def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str:
+    """Process speech and return a simple response"""
+    # In this simplified version, we'll just check if there's sound
+    # and provide basic responses instead of doing actual speech recognition
+
+    if speech_detector and speech_detector.detect_speech(audio_tensor, generator.sample_rate):
+        # Generate a response based on audio energy
+        energy = torch.mean(torch.abs(audio_tensor)).item()
+
+        if energy > 0.1: # Louder speech
+            return "I heard you speaking clearly. How can I help you today?"
+        elif energy > 0.05: # Moderate speech
+            return "I heard you say something. Could you please repeat that?"
+        else: # Soft speech
+            return "I detected some speech, but it was quite soft. Could you speak up a bit?"
+    else:
+        return "I didn't detect any speech. Could you please try again?"
 
 def generate_response(text: str, conversation_history: List[Segment]) -> str:
     """Generate a contextual response based on the transcribed text"""
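The new `process_speech` applies two thresholds to the same buffer: `SpeechDetector` gates on mean squared amplitude (> 0.01), and the canned reply is then chosen from mean absolute amplitude (> 0.1 loud, > 0.05 moderate, otherwise soft). A rough, self-contained check of those branch conditions, using constant-amplitude tensors as stand-ins for real audio:

import torch

loud = 0.3 * torch.ones(24000)    # stand-in for clearly audible speech
quiet = 0.02 * torch.ones(24000)  # stand-in for near-silence

print(torch.mean(loud ** 2).item() > 0.01)       # True  -> detector fires
print(torch.mean(torch.abs(loud)).item() > 0.1)  # True  -> "heard you speaking clearly" reply
print(torch.mean(quiet ** 2).item() > 0.01)      # False -> "didn't detect any speech" reply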
@@ -394,7 +295,7 @@ def generate_response(text: str, conversation_history: List[Segment]) -> str:
     elif len(text) < 10:
         return "Thanks for your message. Could you elaborate a bit more?"
     else:
-        return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"
+        return f"I heard you speaking. That's interesting! Can you tell me more about that?"
 
 # Flask Routes
 @app.route('/')
@@ -610,33 +511,32 @@ def process_complete_utterance(client_id, client, speaker_id, is_incomplete=Fals
     # Combine audio chunks
     full_audio = torch.cat(client['streaming_buffer'], dim=0)
 
-    # Process with speech-to-text
-    logger.info(f"[{client_id[:8]}] Starting transcription...")
-    transcribed_text = transcribe_audio(full_audio)
+    # Process audio to generate a response (no speech recognition)
+    generated_text = process_speech(full_audio, client_id)
 
     # Add suffix for incomplete utterances
     if is_incomplete:
-        transcribed_text += " (processing continued speech...)"
+        generated_text += " (processing continued speech...)"
 
-    # Log the transcription
-    logger.info(f"[{client_id[:8]}] Transcribed: '{transcribed_text}'")
+    # Log the generated text
+    logger.info(f"[{client_id[:8]}] Generated text: '{generated_text}'")
 
-    # Handle the transcription result
-    if transcribed_text:
+    # Handle the result
+    if generated_text:
         # Add user message to context
-        user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
+        user_segment = Segment(text=generated_text, speaker=speaker_id, audio=full_audio)
         client['context_segments'].append(user_segment)
 
-        # Send the transcribed text to client
+        # Send the text to client
         emit('transcription', {
             'type': 'transcription',
-            'text': transcribed_text
+            'text': generated_text
         }, room=client_id)
 
         # Only generate a response if this is a complete utterance
         if not is_incomplete:
             # Generate a contextual response
-            response_text = generate_response(transcribed_text, client['context_segments'])
+            response_text = generate_response(generated_text, client['context_segments'])
             logger.info(f"[{client_id[:8]}] Generating response: '{response_text}'")
 
             # Let the client know we're processing
@@ -684,7 +584,7 @@ def process_complete_utterance(client_id, client, speaker_id, is_incomplete=Fals
                 'message': "Sorry, there was an error generating the audio response."
             }, room=client_id)
     else:
-        # If transcription failed, send a notification
+        # If processing failed, send a notification
         emit('error', {
             'type': 'error',
             'message': "Sorry, I couldn't understand what you said. Could you try again?"
@@ -791,7 +691,7 @@ if __name__ == "__main__":
     print(f" - Network URL: http://<your-ip-address>:5000")
     print(f"{'='*60}")
     print(f"🌐 Device: {device.upper()}")
-    print(f"🧠 Models: Sesame CSM + WhisperX ASR")
+    print(f"🧠 Models: Sesame CSM (TTS only)")
     print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}")
     print(f"{'='*60}")
     print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n")