Demo Update 6
@@ -1,9 +1,13 @@
import os
import base64
import json
import time
import math
import gc
import logging
import numpy as np
import torch
import torchaudio
import whisperx
from io import BytesIO
from typing import List, Dict, Any, Optional
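The added imports (time, math, gc, logging, numpy) are mostly standard library, plus numpy. As a convenience, a plausible install line covering everything this file imports is sketched below; the PyPI package names are assumptions, not part of the commit:

# Assumed dependency set for this file; adjust to the project's actual requirements:
#   pip install flask flask-cors flask-socketio eventlet torch torchaudio numpy whisperx soundfile scipy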
@@ -11,290 +15,314 @@ from flask import Flask, request, send_from_directory, Response
from flask_cors import CORS
from flask_socketio import SocketIO, emit, disconnect
from generator import load_csm_1b, Segment
from collections import deque
from threading import Lock

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("sesame-server")

# CUDA Environment Setup
def setup_cuda_environment():
    """Set up CUDA environment with proper error handling"""
    # Search for CUDA libraries in common locations
    cuda_lib_dirs = [
        "/usr/local/cuda/lib64",
        "/usr/lib/x86_64-linux-gnu",
        "/usr/local/cuda/extras/CUPTI/lib64"
    ]

    # Add directories to LD_LIBRARY_PATH if they exist
    current_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
    for cuda_dir in cuda_lib_dirs:
        if os.path.exists(cuda_dir) and cuda_dir not in current_ld_path:
            if current_ld_path:
                os.environ['LD_LIBRARY_PATH'] = f"{current_ld_path}:{cuda_dir}"
            else:
                os.environ['LD_LIBRARY_PATH'] = cuda_dir
            current_ld_path = os.environ['LD_LIBRARY_PATH']

    logger.info(f"LD_LIBRARY_PATH set to: {os.environ.get('LD_LIBRARY_PATH', 'not set')}")

    # Determine best compute device
    device = "cpu"
    compute_type = "int8"

    try:
        # Set CUDA preferences
        os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Limit to first GPU only

        # Try enabling TF32 precision if available
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
        except Exception as e:
            logger.warning(f"Could not set advanced CUDA options: {e}")

        # Test if CUDA is functional
        if torch.cuda.is_available():
            try:
                # Test basic CUDA operations
                x = torch.rand(10, device="cuda")
                y = x + x
                del x, y
                torch.cuda.empty_cache()
                device = "cuda"
                compute_type = "float16"
                logger.info("CUDA is fully functional")
            except Exception as e:
                logger.warning(f"CUDA available but not working correctly: {e}")
                device = "cpu"
        else:
            logger.info("CUDA is not available, using CPU")
    except Exception as e:
        logger.error(f"Error setting up computing environment: {e}")

    return device, compute_type

# Set up the compute environment
device, compute_type = setup_cuda_environment()

# Constants and Configuration
SILENCE_THRESHOLD = 0.01
SILENCE_DURATION_SEC = 0.75
MAX_BUFFER_SIZE = 30 # Maximum chunks to buffer before processing
CHUNK_SIZE_MS = 500 # Size of audio chunks when streaming responses

# Define the base directory and static files directory
base_dir = os.path.dirname(os.path.abspath(__file__))
static_dir = os.path.join(base_dir, "static")
os.makedirs(static_dir, exist_ok=True)

# Model Loading Functions
def load_speech_models():
    """Load all required speech models with fallbacks"""
    # Load speech generation model (Sesame CSM)
    try:
        logger.info(f"Loading Sesame CSM model on {device}...")
        generator = load_csm_1b(device=device)
        logger.info("Sesame CSM model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading Sesame CSM on {device}: {e}")
        if device == "cuda":
            try:
                logger.info("Trying to load Sesame CSM on CPU instead...")
                generator = load_csm_1b(device="cpu")
                logger.info("Sesame CSM model loaded on CPU successfully")
            except Exception as cpu_error:
                logger.critical(f"Failed to load speech synthesis model: {cpu_error}")
                raise RuntimeError("Failed to load speech synthesis model")
        else:
            raise RuntimeError("Failed to load speech synthesis model on any device")

    # Load ASR model (WhisperX)
    try:
        logger.info("Loading WhisperX model...")
        # Start with the tiny model on CPU for reliable initialization
        asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
        logger.info("WhisperX 'tiny' model loaded on CPU successfully")

        # Try upgrading to GPU if available
        if device == "cuda":
            try:
                logger.info("Trying to load WhisperX on CUDA...")
                # Test with a tiny model first
                test_audio = torch.zeros(16000) # 1 second of silence

                cuda_model = whisperx.load_model("tiny", "cuda", compute_type="float16")
                # Test the model with real inference
                _ = cuda_model.transcribe(test_audio.numpy(), batch_size=1)
                asr_model = cuda_model
                logger.info("WhisperX model running on CUDA successfully")

                # Try to upgrade to small model
                try:
                    small_model = whisperx.load_model("small", "cuda", compute_type="float16")
                    _ = small_model.transcribe(test_audio.numpy(), batch_size=1)
                    asr_model = small_model
                    logger.info("WhisperX 'small' model loaded on CUDA successfully")
                except Exception as e:
                    logger.warning(f"Staying with 'tiny' model on CUDA: {e}")
            except Exception as e:
                logger.warning(f"CUDA loading failed, staying with CPU model: {e}")
    except Exception as e:
        logger.error(f"Error loading WhisperX model: {e}")
        # Create a minimal dummy model as last resort
        class DummyModel:
            def __init__(self):
                self.device = "cpu"
            def transcribe(self, *args, **kwargs):
                return {"segments": [{"text": "Speech recognition currently unavailable."}]}

        asr_model = DummyModel()
        logger.warning("Using dummy transcription model - ASR functionality limited")

    return generator, asr_model

# Load speech models
generator, asr_model = load_speech_models()

# Set up Flask and Socket.IO
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')

# Socket connection management
thread_lock = Lock()
active_clients = {} # Map client_id to client context

# Audio Utility Functions
def decode_audio_data(audio_data: str) -> torch.Tensor:
    """Decode base64 audio data to a torch tensor with improved error handling"""
    try:
        # Skip empty audio data
        if not audio_data or len(audio_data) < 100:
            logger.warning("Empty or too short audio data received")
            return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence

        # Extract the actual base64 content
        if ',' in audio_data:
            audio_data = audio_data.split(',')[1]

        # Decode base64 audio data
        try:
            binary_data = base64.b64decode(audio_data)
            logger.debug(f"Decoded base64 data: {len(binary_data)} bytes")

            # Check if we have enough data for a valid WAV
            if len(binary_data) < 44: # WAV header is 44 bytes
                logger.warning("Data too small to be a valid WAV file")
                return torch.zeros(generator.sample_rate // 2)
        except Exception as e:
            logger.error(f"Base64 decoding error: {e}")
            return torch.zeros(generator.sample_rate // 2)

        # Multiple approaches to handle audio data
        audio_tensor = None
        sample_rate = None

        # Approach 1: Direct loading with torchaudio
        try:
            with BytesIO(binary_data) as temp_file:
                temp_file.seek(0)
                audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")
                logger.debug(f"Loaded audio: shape={audio_tensor.shape}, rate={sample_rate}Hz")

                # Validate tensor
                if audio_tensor.numel() == 0 or torch.isnan(audio_tensor).any():
                    raise ValueError("Invalid audio tensor")
        except Exception as e:
            logger.warning(f"Direct loading failed: {e}")

            # Approach 2: Using wave module and numpy
            try:
                temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav")
                with open(temp_path, 'wb') as f:
                    f.write(binary_data)

                import wave
                with wave.open(temp_path, 'rb') as wf:
                    n_channels = wf.getnchannels()
                    sample_width = wf.getsampwidth()
                    sample_rate = wf.getframerate()
                    n_frames = wf.getnframes()
                    frames = wf.readframes(n_frames)

                # Convert to numpy array
                if sample_width == 2: # 16-bit audio
                    data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
                elif sample_width == 1: # 8-bit audio
                    data = np.frombuffer(frames, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0
                else:
                    raise ValueError(f"Unsupported sample width: {sample_width}")

                # Convert to mono if needed
                if n_channels > 1:
                    data = data.reshape(-1, n_channels)
                    data = data.mean(axis=1)

                # Convert to torch tensor
                audio_tensor = torch.from_numpy(data)
                logger.info(f"Loaded audio using wave: shape={audio_tensor.shape}")

                # Clean up temp file
                if os.path.exists(temp_path):
                    os.remove(temp_path)

            except Exception as e2:
                logger.error(f"All audio loading methods failed: {e2}")
                return torch.zeros(generator.sample_rate // 2)

        # Format corrections
        if audio_tensor is None:
            return torch.zeros(generator.sample_rate // 2)

        # Ensure audio is mono
        if len(audio_tensor.shape) > 1 and audio_tensor.shape[0] > 1:
            audio_tensor = torch.mean(audio_tensor, dim=0)

        # Ensure 1D tensor
        audio_tensor = audio_tensor.squeeze()

        # Resample if needed
        if sample_rate != generator.sample_rate:
            try:
                logger.debug(f"Resampling from {sample_rate}Hz to {generator.sample_rate}Hz")
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=generator.sample_rate
                )
                audio_tensor = resampler(audio_tensor)
            except Exception as e:
                logger.warning(f"Resampling error: {e}")

        # Normalize audio to avoid issues
        if torch.abs(audio_tensor).max() > 0:
            audio_tensor = audio_tensor / torch.abs(audio_tensor).max()

        return audio_tensor
    except Exception as e:
        logger.error(f"Unhandled error in decode_audio_data: {e}")
        return torch.zeros(generator.sample_rate // 2)


def encode_audio_data(audio_tensor: torch.Tensor) -> str:
    """Encode torch tensor audio to base64 string"""
    try:
        buf = BytesIO()
        torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
        buf.seek(0)
        audio_base64 = base64.b64encode(buf.read()).decode('utf-8')
        return f"data:audio/wav;base64,{audio_base64}"
    except Exception as e:
        logger.error(f"Error encoding audio: {e}")
        # Return a minimal silent audio file
        silence = torch.zeros(generator.sample_rate // 2).unsqueeze(0)
        buf = BytesIO()
        torchaudio.save(buf, silence, generator.sample_rate, format="wav")
        buf.seek(0)
        return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}"

def transcribe_audio(audio_tensor: torch.Tensor) -> str:
    """Transcribe audio using WhisperX with robust error handling"""
    global asr_model

    try:
        # Save the tensor to a temporary file
        temp_path = os.path.join(base_dir, f"temp_audio_{time.time()}.wav")
        torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)

        logger.info(f"Transcribing audio file: {os.path.getsize(temp_path)} bytes")

        # Load the audio for WhisperX
        try:
            audio = whisperx.load_audio(temp_path)
        except Exception as e:
            logger.warning(f"WhisperX load_audio failed: {e}")
            # Fall back to manual loading
            import soundfile as sf
            audio, sr = sf.read(temp_path)
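The per-chunk energy tracking that feeds SILENCE_THRESHOLD and the client's energy_window, is_silence, and last_active_time fields lives in a part of handle_stream_audio that falls outside the hunks shown here. A minimal sketch of how that bookkeeping is typically wired up follows; the helper name and exact windowing are assumptions, not code from this commit:

def update_silence_state(client: Dict[str, Any], audio_chunk: torch.Tensor) -> None:
    """Hypothetical helper: track per-chunk RMS energy and flag sustained silence."""
    # RMS energy of the incoming chunk (chunks are normalized to [-1, 1] by decode_audio_data)
    energy = torch.sqrt(torch.mean(audio_chunk ** 2)).item()
    client['energy_window'].append(energy)  # deque(maxlen=...) of recent chunk energies

    # Treat the stream as silent when the recent average energy stays below the threshold
    avg_energy = sum(client['energy_window']) / len(client['energy_window'])
    if avg_energy < SILENCE_THRESHOLD:
        if not client['is_silence']:
            client['is_silence'] = True
            client['last_active_time'] = time.time()
    else:
        client['is_silence'] = False
        client['last_active_time'] = time.time()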
@@ -302,59 +330,55 @@ def transcribe_audio(audio_tensor: torch.Tensor) -> str:
                from scipy import signal
                audio = signal.resample(audio, int(len(audio) * 16000 / sr))

        # Transcribe with error handling
        try:
            result = asr_model.transcribe(audio, batch_size=4)
        except RuntimeError as e:
            if "CUDA" in str(e) or "libcudnn" in str(e):
                logger.warning(f"CUDA error in transcription, falling back to CPU: {e}")

                try:
                    # Try CPU model
                    cpu_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
                    result = cpu_model.transcribe(audio, batch_size=1)
                    # Update the global model if the original one is broken
                    asr_model = cpu_model
                except Exception as cpu_e:
                    logger.error(f"CPU fallback failed: {cpu_e}")
                    return "I'm having trouble processing audio right now."
            else:
                raise
        finally:
            # Clean up
            if os.path.exists(temp_path):
                os.remove(temp_path)

        # Extract text from segments
        if result["segments"] and len(result["segments"]) > 0:
            transcription = " ".join([segment["text"] for segment in result["segments"]])
            logger.info(f"Transcription: '{transcription.strip()}'")
            return transcription.strip()

        return ""
    except Exception as e:
        logger.error(f"Error in transcription: {e}")
        if os.path.exists(temp_path):
            os.remove(temp_path)
        return "I heard something but couldn't understand it."


def generate_response(text: str, conversation_history: List[Segment]) -> str:
    """Generate a contextual response based on the transcribed text"""
    # Simple response logic - can be replaced with a more sophisticated LLM
    responses = {
        "hello": "Hello there! How can I help you today?",
        "hi": "Hi there! What can I do for you?",
        "how are you": "I'm doing well, thanks for asking! How about you?",
        "what is your name": "I'm Sesame, your voice assistant. How can I help you?",
        "who are you": "I'm Sesame, an AI voice assistant. I'm here to chat with you!",
        "bye": "Goodbye! It was nice chatting with you.",
        "thank you": "You're welcome! Is there anything else I can help with?",
        "weather": "I don't have real-time weather data, but I hope it's nice where you are!",
        "help": "I can chat with you using natural voice. Just speak normally and I'll respond.",
        "what can you do": "I can have a conversation with you, answer questions, and provide assistance with various topics.",
    }

    text_lower = text.lower()
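The keyword lookup between text_lower and the else: branch shown in the next hunk is not part of this diff; presumably it scans the responses table for a substring hit, roughly like this sketch (an assumption, not the committed code):

    # Return the canned reply for the first keyword found in the transcription
    for keyword, reply in responses.items():
        if keyword in text_lower:
            return reply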
@@ -372,7 +396,7 @@ def generate_response(text: str, conversation_history: List[Segment]) -> str:
    else:
        return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"

# Flask Routes
@app.route('/')
def index():
    return send_from_directory(base_dir, 'index.html')
@@ -391,11 +415,11 @@ def voice_chat_js():
def serve_static(path):
    return send_from_directory(static_dir, path)

# Socket.IO Event Handlers
@socketio.on('connect')
def handle_connect():
    client_id = request.sid
    logger.info(f"Client connected: {client_id}")

    # Initialize client context
    active_clients[client_id] = {
@@ -414,7 +438,7 @@ def handle_disconnect():
    client_id = request.sid
    if client_id in active_clients:
        del active_clients[client_id]
    logger.info(f"Client disconnected: {client_id}")

@socketio.on('generate')
def handle_generate(data):
@@ -427,7 +451,7 @@ def handle_generate(data):
        text = data.get('text', '')
        speaker_id = data.get('speaker', 0)

        logger.info(f"Generating audio for: '{text}' with speaker {speaker_id}")

        # Generate audio response
        audio_tensor = generator.generate(
@@ -446,11 +470,12 @@ def handle_generate(data):
        audio_base64 = encode_audio_data(audio_tensor)
        emit('audio_response', {
            'type': 'audio_response',
            'audio': audio_base64,
            'text': text
        })

    except Exception as e:
        logger.error(f"Error generating audio: {e}")
        emit('error', {
            'type': 'error',
            'message': f"Error generating audio: {str(e)}"
@@ -482,7 +507,7 @@ def handle_add_to_context(data):
        })

    except Exception as e:
        logger.error(f"Error adding to context: {e}")
        emit('error', {
            'type': 'error',
            'message': f"Error processing audio: {str(e)}"
@@ -512,6 +537,11 @@ def handle_stream_audio(data):
        speaker_id = data.get('speaker', 0)
        audio_data = data.get('audio', '')

        # Skip if no audio data (might be just a connection test)
        if not audio_data:
            logger.debug("Empty audio data received, ignoring")
            return

        # Convert received audio to tensor
        audio_chunk = decode_audio_data(audio_data)
@@ -522,7 +552,7 @@ def handle_stream_audio(data):
            client['energy_window'].clear()
            client['is_silence'] = False
            client['last_active_time'] = time.time()
            logger.info(f"[{client_id[:8]}] Streaming started with speaker ID: {speaker_id}")
            emit('streaming_status', {
                'type': 'streaming_status',
                'status': 'started'
@@ -553,52 +583,74 @@ def handle_stream_audio(data):
        if client['is_silence'] and silence_elapsed >= SILENCE_DURATION_SEC and len(client['streaming_buffer']) > 0:
            # User has stopped talking - process the collected audio
            logger.info(f"[{client_id[:8]}] Processing audio after {silence_elapsed:.2f}s of silence")
            process_complete_utterance(client_id, client, speaker_id)

        # If buffer gets too large without silence, process it anyway
        elif len(client['streaming_buffer']) >= MAX_BUFFER_SIZE:
            logger.info(f"[{client_id[:8]}] Processing long audio segment without silence")
            process_complete_utterance(client_id, client, speaker_id, is_incomplete=True)

            # Keep half of the buffer for context (sliding window approach)
            half_point = len(client['streaming_buffer']) // 2
            client['streaming_buffer'] = client['streaming_buffer'][half_point:]

    except Exception as e:
        import traceback
        traceback.print_exc()
        logger.error(f"Error processing streaming audio: {e}")
        emit('error', {
            'type': 'error',
            'message': f"Error processing streaming audio: {str(e)}"
        })

def process_complete_utterance(client_id, client, speaker_id, is_incomplete=False):
    """Process a complete utterance (after silence or buffer limit)"""
    try:
        # Combine audio chunks
        full_audio = torch.cat(client['streaming_buffer'], dim=0)

        # Process with speech-to-text
        logger.info(f"[{client_id[:8]}] Starting transcription...")
        transcribed_text = transcribe_audio(full_audio)

        # Add suffix for incomplete utterances
        if is_incomplete:
            transcribed_text += " (processing continued speech...)"

        # Log the transcription
        logger.info(f"[{client_id[:8]}] Transcribed: '{transcribed_text}'")

        # Handle the transcription result
        if transcribed_text:
            # Add user message to context
            user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
            client['context_segments'].append(user_segment)

            # Send the transcribed text to client
            emit('transcription', {
                'type': 'transcription',
                'text': transcribed_text
            }, room=client_id)

            # Only generate a response if this is a complete utterance
            if not is_incomplete:
                # Generate a contextual response
                response_text = generate_response(transcribed_text, client['context_segments'])
                logger.info(f"[{client_id[:8]}] Generating response: '{response_text}'")

                # Let the client know we're processing
                emit('processing_status', {
                    'type': 'processing_status',
                    'status': 'generating_audio',
                    'message': 'Generating audio response...'
                }, room=client_id)

                # Generate audio for the response
                try:
                    # Use a different speaker than the user
                    ai_speaker_id = 1 if speaker_id == 0 else 0

                    # Generate the full response
                    audio_tensor = generator.generate(
                        text=response_text,
@@ -621,60 +673,37 @@ def handle_stream_audio(data):
                        'type': 'audio_response',
                        'text': response_text,
                        'audio': audio_base64
                    }, room=client_id)

                    logger.info(f"[{client_id[:8]}] Audio response sent")

                except Exception as e:
                    logger.error(f"Error generating audio response: {e}")
                    emit('error', {
                        'type': 'error',
                        'message': "Sorry, there was an error generating the audio response."
                    }, room=client_id)
        else:
            # If transcription failed, send a notification
            emit('error', {
                'type': 'error',
                'message': "Sorry, I couldn't understand what you said. Could you try again?"
            }, room=client_id)

        # Only clear buffer for complete utterances
        if not is_incomplete:
            # Reset state
            client['streaming_buffer'] = []
            client['energy_window'].clear()
            client['is_silence'] = False
            client['last_active_time'] = time.time()

    except Exception as e:
        logger.error(f"Error processing utterance: {e}")
        emit('error', {
            'type': 'error',
            'message': f"Error processing audio: {str(e)}"
        }, room=client_id)

@socketio.on('stop_streaming')
def handle_stop_streaming(data):
@@ -687,21 +716,8 @@ def handle_stop_streaming(data):
    if client['streaming_buffer'] and len(client['streaming_buffer']) > 5:
        # Process any remaining audio in the buffer
        logger.info(f"[{client_id[:8]}] Processing final audio buffer on stop")
        process_complete_utterance(client_id, client, data.get("speaker", 0))

    client['streaming_buffer'] = []
    emit('streaming_status', {
@@ -709,18 +725,18 @@ def handle_stop_streaming(data):
        'status': 'stopped'
    })

def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size_ms=CHUNK_SIZE_MS):
    """Stream audio to client in chunks to simulate real-time generation"""
    try:
        if client_id not in active_clients:
            logger.warning(f"Client {client_id} not found for streaming")
            return

        # Calculate chunk size in samples
        chunk_size = int(generator.sample_rate * chunk_size_ms / 1000)
        total_chunks = math.ceil(audio_tensor.size(0) / chunk_size)

        logger.info(f"Streaming audio in {total_chunks} chunks of {chunk_size_ms}ms each")

        # Send initial response with text but no audio yet
        socketio.emit('audio_response_start', {
@@ -758,29 +774,24 @@ def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size
            'text': text
        }, room=client_id)

        logger.info(f"Audio streaming complete: {total_chunks} chunks sent")

    except Exception as e:
        logger.error(f"Error streaming audio to client: {e}")
        import traceback
        traceback.print_exc()

# Main server start
if __name__ == "__main__":
    print(f"\n{'='*60}")
    print(f"🔊 Sesame AI Voice Chat Server")
    print(f"{'='*60}")
    print(f"📡 Server Information:")
    print(f"   - Local URL: http://localhost:5000")
    print(f"   - Network URL: http://<your-ip-address>:5000")
    print(f"{'='*60}")
    print(f"🌐 Device: {device.upper()}")
    print(f"🧠 Models: Sesame CSM + WhisperX ASR")
    print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}")
    print(f"{'='*60}")
    print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n")
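The tail of the file (the per-chunk emit loop inside stream_audio_to_client and the call that actually starts the server, presumably something like socketio.run(app, host='0.0.0.0', port=5000)) falls outside the hunks above. For a quick smoke test of the 'generate' event, a minimal client could look like the sketch below; the python-socketio package, the localhost URL, and port 5000 are assumptions taken from the startup banner, not from the diff:

# Hypothetical smoke-test client; event names match the handlers in this commit,
# but the server URL and port are assumptions taken from the startup banner.
import socketio

sio = socketio.Client()

@sio.on('audio_response')
def on_audio_response(data):
    # 'audio' is a data:audio/wav;base64,... string produced by encode_audio_data()
    print(f"Got audio for text {data.get('text')!r} ({len(data.get('audio', ''))} chars)")
    sio.disconnect()

sio.connect('http://localhost:5000')
sio.emit('generate', {'text': 'Hello from the test client', 'speaker': 0})
sio.wait()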