Merge branch 'main' of https://github.com/GamerBoss101/HooHacks-12
This commit is contained in:
1068
Backend/index.html
1068
Backend/index.html
File diff suppressed because it is too large
Load Diff
@@ -1,99 +1,276 @@
|
|||||||
import os
|
import os
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import asyncio
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import io
|
|
||||||
import whisperx
|
import whisperx
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
from flask import Flask, request, send_from_directory, Response
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from flask_cors import CORS
|
||||||
from pydantic import BaseModel
|
from flask_socketio import SocketIO, emit, disconnect
|
||||||
from generator import load_csm_1b, Segment
|
from generator import load_csm_1b, Segment
|
||||||
import uvicorn
|
|
||||||
import time
|
import time
|
||||||
import gc
|
import gc
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
# Select device
|
# Add this at the top of your file, replacing your current CUDA setup
|
||||||
|
|
||||||
|
# CUDA setup with robust error handling
|
||||||
|
try:
|
||||||
|
# Handle CUDA issues
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Limit to first GPU only
|
||||||
|
|
||||||
|
# Try enabling TF32 precision
|
||||||
|
try:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
except:
|
||||||
|
pass # Ignore if not supported
|
||||||
|
|
||||||
|
# Check if CUDA is available
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
try:
|
||||||
|
# Test CUDA functionality
|
||||||
|
x = torch.rand(10, device="cuda")
|
||||||
|
y = x + x
|
||||||
|
del x, y
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
compute_type = "float16"
|
||||||
|
print("CUDA is fully functional")
|
||||||
|
except Exception as cuda_error:
|
||||||
|
print(f"CUDA is available but not working correctly: {str(cuda_error)}")
|
||||||
|
device = "cpu"
|
||||||
|
compute_type = "int8"
|
||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
print(f"Using device: {device}")
|
compute_type = "int8"
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error setting up CUDA: {str(e)}")
|
||||||
|
device = "cpu"
|
||||||
|
compute_type = "int8"
|
||||||
|
|
||||||
# Initialize the model
|
print(f"Using device: {device} with compute type: {compute_type}")
|
||||||
|
|
||||||
|
# Initialize the Sesame CSM model with robust error handling
|
||||||
|
try:
|
||||||
|
print(f"Loading Sesame CSM model on {device}...")
|
||||||
generator = load_csm_1b(device=device)
|
generator = load_csm_1b(device=device)
|
||||||
|
print("Sesame CSM model loaded successfully")
|
||||||
|
except Exception as model_error:
|
||||||
|
print(f"Error loading Sesame CSM on {device}: {str(model_error)}")
|
||||||
|
if device == "cuda":
|
||||||
|
# Try on CPU as fallback
|
||||||
|
try:
|
||||||
|
print("Trying to load Sesame CSM on CPU instead...")
|
||||||
|
device = "cpu" # Update global device setting
|
||||||
|
generator = load_csm_1b(device="cpu")
|
||||||
|
print("Sesame CSM model loaded on CPU successfully")
|
||||||
|
except Exception as cpu_error:
|
||||||
|
print(f"Fatal error - could not load Sesame CSM model: {str(cpu_error)}")
|
||||||
|
raise RuntimeError("Failed to load speech synthesis model")
|
||||||
|
else:
|
||||||
|
# Already tried CPU and it failed
|
||||||
|
raise RuntimeError("Failed to load speech synthesis model on any device")
|
||||||
|
|
||||||
# Initialize WhisperX for ASR
|
# Replace the WhisperX model loading section
|
||||||
|
|
||||||
|
# Initialize WhisperX for ASR with robust error handling
|
||||||
print("Loading WhisperX model...")
|
print("Loading WhisperX model...")
|
||||||
# Use a smaller model for faster response times
|
asr_model = None # Initialize to None first to avoid scope issues
|
||||||
asr_model = whisperx.load_model("medium", device, compute_type="float16")
|
|
||||||
print("WhisperX model loaded!")
|
|
||||||
|
|
||||||
app = FastAPI()
|
try:
|
||||||
|
# Always start with the tiny model on CPU for stability
|
||||||
|
asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
|
||||||
|
print("WhisperX 'tiny' model loaded on CPU successfully")
|
||||||
|
|
||||||
# Add CORS middleware to allow cross-origin requests
|
# If CPU works, try CUDA if available
|
||||||
app.add_middleware(
|
if device == "cuda":
|
||||||
CORSMiddleware,
|
try:
|
||||||
allow_origins=["*"], # Allow all origins in development
|
print("Trying to load WhisperX on CUDA...")
|
||||||
allow_credentials=True,
|
cuda_model = whisperx.load_model("tiny", "cuda", compute_type="float16")
|
||||||
allow_methods=["*"],
|
# Test the model to ensure it works
|
||||||
allow_headers=["*"],
|
test_audio = torch.zeros(16000) # 1 second of silence at 16kHz
|
||||||
)
|
_ = cuda_model.transcribe(test_audio.numpy(), batch_size=1)
|
||||||
|
# If we get here, CUDA works
|
||||||
|
asr_model = cuda_model
|
||||||
|
print("WhisperX model moved to CUDA successfully")
|
||||||
|
|
||||||
# Connection manager to handle multiple clients
|
# Try to upgrade to small model on CUDA
|
||||||
class ConnectionManager:
|
try:
|
||||||
|
small_model = whisperx.load_model("small", "cuda", compute_type="float16")
|
||||||
|
# Test it
|
||||||
|
_ = small_model.transcribe(test_audio.numpy(), batch_size=1)
|
||||||
|
asr_model = small_model
|
||||||
|
print("WhisperX 'small' model loaded on CUDA successfully")
|
||||||
|
except Exception as upgrade_error:
|
||||||
|
print(f"Staying with 'tiny' model on CUDA: {str(upgrade_error)}")
|
||||||
|
except Exception as cuda_error:
|
||||||
|
print(f"CUDA loading failed, staying with CPU model: {str(cuda_error)}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading WhisperX model: {str(e)}")
|
||||||
|
# Create a minimal dummy model as last resort
|
||||||
|
class DummyModel:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.active_connections: List[WebSocket] = []
|
self.device = "cpu"
|
||||||
|
def transcribe(self, *args, **kwargs):
|
||||||
|
return {"segments": [{"text": "Speech recognition currently unavailable."}]}
|
||||||
|
|
||||||
async def connect(self, websocket: WebSocket):
|
asr_model = DummyModel()
|
||||||
await websocket.accept()
|
print("WARNING: Using dummy transcription model - ASR functionality limited")
|
||||||
self.active_connections.append(websocket)
|
|
||||||
|
|
||||||
def disconnect(self, websocket: WebSocket):
|
|
||||||
self.active_connections.remove(websocket)
|
|
||||||
|
|
||||||
manager = ConnectionManager()
|
|
||||||
|
|
||||||
# Silence detection parameters
|
# Silence detection parameters
|
||||||
SILENCE_THRESHOLD = 0.01 # Adjust based on your audio normalization
|
SILENCE_THRESHOLD = 0.01 # Adjust based on your audio normalization
|
||||||
SILENCE_DURATION_SEC = 1.0 # How long silence must persist to be considered "stopped talking"
|
SILENCE_DURATION_SEC = 1.0 # How long silence must persist
|
||||||
|
|
||||||
|
# Define the base directory
|
||||||
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
static_dir = os.path.join(base_dir, "static")
|
||||||
|
os.makedirs(static_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Setup Flask
|
||||||
|
app = Flask(__name__)
|
||||||
|
CORS(app)
|
||||||
|
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
|
||||||
|
|
||||||
|
# Socket connection management
|
||||||
|
thread = None
|
||||||
|
thread_lock = Lock()
|
||||||
|
active_clients = {} # Map client_id to client context
|
||||||
|
|
||||||
# Helper function to convert audio data
|
# Helper function to convert audio data
|
||||||
async def decode_audio_data(audio_data: str) -> torch.Tensor:
|
def decode_audio_data(audio_data: str) -> torch.Tensor:
|
||||||
"""Decode base64 audio data to a torch tensor"""
|
"""Decode base64 audio data to a torch tensor with improved error handling"""
|
||||||
try:
|
try:
|
||||||
|
# Skip empty audio data
|
||||||
|
if not audio_data or len(audio_data) < 100:
|
||||||
|
print("Empty or too short audio data received")
|
||||||
|
return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence
|
||||||
|
|
||||||
|
# Extract the actual base64 content
|
||||||
|
if ',' in audio_data:
|
||||||
|
# Handle data URL format (data:audio/wav;base64,...)
|
||||||
|
audio_data = audio_data.split(',')[1]
|
||||||
|
|
||||||
# Decode base64 audio data
|
# Decode base64 audio data
|
||||||
binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data)
|
try:
|
||||||
|
binary_data = base64.b64decode(audio_data)
|
||||||
|
print(f"Decoded base64 data: {len(binary_data)} bytes")
|
||||||
|
|
||||||
# Save to a temporary WAV file first
|
# Check if we have enough data for a valid WAV
|
||||||
temp_file = BytesIO(binary_data)
|
if len(binary_data) < 44: # WAV header is 44 bytes
|
||||||
|
print("Data too small to be a valid WAV file")
|
||||||
|
return torch.zeros(generator.sample_rate // 2)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Base64 decoding error: {str(e)}")
|
||||||
|
return torch.zeros(generator.sample_rate // 2)
|
||||||
|
|
||||||
# Load audio from binary data, explicitly specifying the format
|
# Save for debugging
|
||||||
|
debug_path = os.path.join(base_dir, "debug_incoming.wav")
|
||||||
|
with open(debug_path, 'wb') as f:
|
||||||
|
f.write(binary_data)
|
||||||
|
print(f"Saved debug file: {debug_path}")
|
||||||
|
|
||||||
|
# Approach 1: Load directly with torchaudio
|
||||||
|
try:
|
||||||
|
with BytesIO(binary_data) as temp_file:
|
||||||
|
temp_file.seek(0) # Ensure we're at the start of the buffer
|
||||||
audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")
|
audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")
|
||||||
|
print(f"Direct loading success: shape={audio_tensor.shape}, rate={sample_rate}Hz")
|
||||||
|
|
||||||
|
# Check if audio is valid
|
||||||
|
if audio_tensor.numel() == 0 or torch.isnan(audio_tensor).any():
|
||||||
|
raise ValueError("Empty or invalid audio tensor detected")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Direct loading failed: {str(e)}")
|
||||||
|
|
||||||
|
# Approach 2: Try to fix/normalize the WAV data
|
||||||
|
try:
|
||||||
|
# Sometimes WAV headers can be malformed, attempt to fix
|
||||||
|
temp_path = os.path.join(base_dir, "temp_fixing.wav")
|
||||||
|
with open(temp_path, 'wb') as f:
|
||||||
|
f.write(binary_data)
|
||||||
|
|
||||||
|
# Use a simpler numpy approach as backup
|
||||||
|
import numpy as np
|
||||||
|
import wave
|
||||||
|
|
||||||
|
try:
|
||||||
|
with wave.open(temp_path, 'rb') as wf:
|
||||||
|
n_channels = wf.getnchannels()
|
||||||
|
sample_width = wf.getsampwidth()
|
||||||
|
sample_rate = wf.getframerate()
|
||||||
|
n_frames = wf.getnframes()
|
||||||
|
|
||||||
|
# Read the frames
|
||||||
|
frames = wf.readframes(n_frames)
|
||||||
|
print(f"Wave reading: channels={n_channels}, rate={sample_rate}Hz, frames={n_frames}")
|
||||||
|
|
||||||
|
# Convert to numpy and then to torch
|
||||||
|
if sample_width == 2: # 16-bit audio
|
||||||
|
data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
elif sample_width == 1: # 8-bit audio
|
||||||
|
data = np.frombuffer(frames, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported sample width: {sample_width}")
|
||||||
|
|
||||||
|
# Convert to mono if needed
|
||||||
|
if n_channels > 1:
|
||||||
|
data = data.reshape(-1, n_channels)
|
||||||
|
data = data.mean(axis=1)
|
||||||
|
|
||||||
|
# Convert to torch tensor
|
||||||
|
audio_tensor = torch.from_numpy(data)
|
||||||
|
print(f"Successfully converted with numpy: shape={audio_tensor.shape}")
|
||||||
|
except Exception as wave_error:
|
||||||
|
print(f"Wave processing failed: {str(wave_error)}")
|
||||||
|
# Try with torchaudio as last resort
|
||||||
|
audio_tensor, sample_rate = torchaudio.load(temp_path, format="wav")
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
if os.path.exists(temp_path):
|
||||||
|
os.remove(temp_path)
|
||||||
|
except Exception as e2:
|
||||||
|
print(f"All WAV loading methods failed: {str(e2)}")
|
||||||
|
print("Returning silence as fallback")
|
||||||
|
return torch.zeros(generator.sample_rate // 2)
|
||||||
|
|
||||||
|
# Ensure audio is the right shape (mono)
|
||||||
|
if len(audio_tensor.shape) > 1 and audio_tensor.shape[0] > 1:
|
||||||
|
audio_tensor = torch.mean(audio_tensor, dim=0)
|
||||||
|
|
||||||
|
# Ensure we have a 1D tensor
|
||||||
|
audio_tensor = audio_tensor.squeeze()
|
||||||
|
|
||||||
# Resample if needed
|
# Resample if needed
|
||||||
if sample_rate != generator.sample_rate:
|
if sample_rate != generator.sample_rate:
|
||||||
audio_tensor = torchaudio.functional.resample(
|
try:
|
||||||
audio_tensor.squeeze(0),
|
print(f"Resampling from {sample_rate}Hz to {generator.sample_rate}Hz")
|
||||||
|
resampler = torchaudio.transforms.Resample(
|
||||||
orig_freq=sample_rate,
|
orig_freq=sample_rate,
|
||||||
new_freq=generator.sample_rate
|
new_freq=generator.sample_rate
|
||||||
)
|
)
|
||||||
else:
|
audio_tensor = resampler(audio_tensor)
|
||||||
audio_tensor = audio_tensor.squeeze(0)
|
except Exception as e:
|
||||||
|
print(f"Resampling error: {str(e)}")
|
||||||
|
# If resampling fails, just return the original audio
|
||||||
|
# The model can often handle different sample rates
|
||||||
|
|
||||||
|
# Normalize audio to avoid issues
|
||||||
|
if torch.abs(audio_tensor).max() > 0:
|
||||||
|
audio_tensor = audio_tensor / torch.abs(audio_tensor).max()
|
||||||
|
|
||||||
|
print(f"Final audio tensor: shape={audio_tensor.shape}, min={audio_tensor.min().item():.4f}, max={audio_tensor.max().item():.4f}")
|
||||||
return audio_tensor
|
return audio_tensor
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error decoding audio: {str(e)}")
|
print(f"Unhandled error in decode_audio_data: {str(e)}")
|
||||||
# Return a small silent audio segment as fallback
|
# Return a small silent audio segment as fallback
|
||||||
return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence
|
return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence
|
||||||
|
|
||||||
|
|
||||||
async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
|
def encode_audio_data(audio_tensor: torch.Tensor) -> str:
|
||||||
"""Encode torch tensor audio to base64 string"""
|
"""Encode torch tensor audio to base64 string"""
|
||||||
buf = BytesIO()
|
buf = BytesIO()
|
||||||
torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
|
torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
|
||||||
@@ -102,40 +279,72 @@ async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
|
|||||||
return f"data:audio/wav;base64,{audio_base64}"
|
return f"data:audio/wav;base64,{audio_base64}"
|
||||||
|
|
||||||
|
|
||||||
async def transcribe_audio(audio_tensor: torch.Tensor) -> str:
|
def transcribe_audio(audio_tensor: torch.Tensor) -> str:
|
||||||
"""Transcribe audio using WhisperX"""
|
"""Transcribe audio using WhisperX with robust error handling"""
|
||||||
|
global asr_model # Declare global at the beginning of the function
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Save the tensor to a temporary file
|
# Save the tensor to a temporary file
|
||||||
temp_file = BytesIO()
|
temp_path = os.path.join(base_dir, "temp_audio.wav")
|
||||||
torchaudio.save(temp_file, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
|
torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
|
||||||
temp_file.seek(0)
|
|
||||||
|
|
||||||
# Create a temporary file on disk (WhisperX requires a file path)
|
print(f"Transcribing audio file: {temp_path} (size: {os.path.getsize(temp_path)} bytes)")
|
||||||
temp_path = "temp_audio.wav"
|
|
||||||
with open(temp_path, "wb") as f:
|
|
||||||
f.write(temp_file.read())
|
|
||||||
|
|
||||||
# Load and transcribe the audio
|
# Load the audio file using whisperx's function
|
||||||
|
try:
|
||||||
audio = whisperx.load_audio(temp_path)
|
audio = whisperx.load_audio(temp_path)
|
||||||
result = asr_model.transcribe(audio, batch_size=16)
|
except Exception as audio_load_error:
|
||||||
|
print(f"WhisperX load_audio failed: {str(audio_load_error)}")
|
||||||
|
# Fall back to manual loading
|
||||||
|
import soundfile as sf
|
||||||
|
audio, sr = sf.read(temp_path)
|
||||||
|
if sr != 16000: # WhisperX expects 16kHz audio
|
||||||
|
from scipy import signal
|
||||||
|
audio = signal.resample(audio, int(len(audio) * 16000 / sr))
|
||||||
|
|
||||||
|
# Transcribe with error handling for CUDA issues
|
||||||
|
try:
|
||||||
|
# Try with original device
|
||||||
|
result = asr_model.transcribe(audio, batch_size=8)
|
||||||
|
except RuntimeError as cuda_error:
|
||||||
|
if "CUDA" in str(cuda_error) or "libcudnn" in str(cuda_error):
|
||||||
|
print(f"CUDA error in transcription, falling back to CPU: {str(cuda_error)}")
|
||||||
|
|
||||||
|
# Try to load a CPU model as fallback
|
||||||
|
try:
|
||||||
|
# Move model to CPU and try again
|
||||||
|
asr_model = whisperx.load_model("tiny", "cpu", compute_type="int8")
|
||||||
|
result = asr_model.transcribe(audio, batch_size=1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"CPU fallback also failed: {str(e)}")
|
||||||
|
return "I'm having trouble processing audio right now."
|
||||||
|
else:
|
||||||
|
# Re-raise if it's not a CUDA error
|
||||||
|
raise
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
|
if os.path.exists(temp_path):
|
||||||
os.remove(temp_path)
|
os.remove(temp_path)
|
||||||
|
|
||||||
# Get the transcription text
|
# Get the transcription text
|
||||||
if result["segments"] and len(result["segments"]) > 0:
|
if result["segments"] and len(result["segments"]) > 0:
|
||||||
# Combine all segments
|
# Combine all segments
|
||||||
transcription = " ".join([segment["text"] for segment in result["segments"]])
|
transcription = " ".join([segment["text"] for segment in result["segments"]])
|
||||||
print(f"Transcription: {transcription}")
|
print(f"Transcription successful: '{transcription.strip()}'")
|
||||||
return transcription.strip()
|
return transcription.strip()
|
||||||
else:
|
else:
|
||||||
|
print("Transcription returned no segments")
|
||||||
return ""
|
return ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in transcription: {str(e)}")
|
print(f"Error in transcription: {str(e)}")
|
||||||
return ""
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
if os.path.exists("temp_audio.wav"):
|
||||||
|
os.remove("temp_audio.wav")
|
||||||
|
return "I heard something but couldn't understand it."
|
||||||
|
|
||||||
|
|
||||||
async def generate_response(text: str, conversation_history: List[Segment]) -> str:
|
def generate_response(text: str, conversation_history: List[Segment]) -> str:
|
||||||
"""Generate a contextual response based on the transcribed text"""
|
"""Generate a contextual response based on the transcribed text"""
|
||||||
# Simple response logic - can be replaced with a more sophisticated LLM in the future
|
# Simple response logic - can be replaced with a more sophisticated LLM in the future
|
||||||
responses = {
|
responses = {
|
||||||
@@ -163,255 +372,417 @@ async def generate_response(text: str, conversation_history: List[Segment]) -> s
|
|||||||
else:
|
else:
|
||||||
return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"
|
return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"
|
||||||
|
|
||||||
|
# Flask routes for serving static content
|
||||||
|
@app.route('/')
|
||||||
|
def index():
|
||||||
|
return send_from_directory(base_dir, 'index.html')
|
||||||
|
|
||||||
@app.websocket("/ws")
|
@app.route('/favicon.ico')
|
||||||
async def websocket_endpoint(websocket: WebSocket):
|
def favicon():
|
||||||
await manager.connect(websocket)
|
if os.path.exists(os.path.join(static_dir, 'favicon.ico')):
|
||||||
context_segments = [] # Store conversation context
|
return send_from_directory(static_dir, 'favicon.ico')
|
||||||
streaming_buffer = [] # Buffer for streaming audio chunks
|
return Response(status=204)
|
||||||
is_streaming = False
|
|
||||||
|
|
||||||
# Variables for silence detection
|
@app.route('/voice-chat.js')
|
||||||
last_active_time = time.time()
|
def voice_chat_js():
|
||||||
is_silence = False
|
return send_from_directory(base_dir, 'voice-chat.js')
|
||||||
energy_window = deque(maxlen=10) # For tracking recent audio energy
|
|
||||||
|
@app.route('/static/<path:path>')
|
||||||
|
def serve_static(path):
|
||||||
|
return send_from_directory(static_dir, path)
|
||||||
|
|
||||||
|
# Socket.IO event handlers
|
||||||
|
@socketio.on('connect')
|
||||||
|
def handle_connect():
|
||||||
|
client_id = request.sid
|
||||||
|
print(f"Client connected: {client_id}")
|
||||||
|
|
||||||
|
# Initialize client context
|
||||||
|
active_clients[client_id] = {
|
||||||
|
'context_segments': [],
|
||||||
|
'streaming_buffer': [],
|
||||||
|
'is_streaming': False,
|
||||||
|
'is_silence': False,
|
||||||
|
'last_active_time': time.time(),
|
||||||
|
'energy_window': deque(maxlen=10)
|
||||||
|
}
|
||||||
|
|
||||||
|
emit('status', {'type': 'connected', 'message': 'Connected to server'})
|
||||||
|
|
||||||
|
@socketio.on('disconnect')
|
||||||
|
def handle_disconnect():
|
||||||
|
client_id = request.sid
|
||||||
|
if client_id in active_clients:
|
||||||
|
del active_clients[client_id]
|
||||||
|
print(f"Client disconnected: {client_id}")
|
||||||
|
|
||||||
|
@socketio.on('generate')
|
||||||
|
def handle_generate(data):
|
||||||
|
client_id = request.sid
|
||||||
|
if client_id not in active_clients:
|
||||||
|
emit('error', {'message': 'Client not registered'})
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
text = data.get('text', '')
|
||||||
# Receive JSON data from client
|
speaker_id = data.get('speaker', 0)
|
||||||
data = await websocket.receive_text()
|
|
||||||
request = json.loads(data)
|
|
||||||
|
|
||||||
action = request.get("action")
|
print(f"Generating audio for: '{text}' with speaker {speaker_id}")
|
||||||
|
|
||||||
if action == "generate":
|
|
||||||
try:
|
|
||||||
text = request.get("text", "")
|
|
||||||
speaker_id = request.get("speaker", 0)
|
|
||||||
|
|
||||||
# Generate audio response
|
# Generate audio response
|
||||||
print(f"Generating audio for: '{text}' with speaker {speaker_id}")
|
|
||||||
audio_tensor = generator.generate(
|
audio_tensor = generator.generate(
|
||||||
text=text,
|
text=text,
|
||||||
speaker=speaker_id,
|
speaker=speaker_id,
|
||||||
context=context_segments,
|
context=active_clients[client_id]['context_segments'],
|
||||||
max_audio_length_ms=10_000,
|
max_audio_length_ms=10_000,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add to conversation context
|
# Add to conversation context
|
||||||
context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))
|
active_clients[client_id]['context_segments'].append(
|
||||||
|
Segment(text=text, speaker=speaker_id, audio=audio_tensor)
|
||||||
|
)
|
||||||
|
|
||||||
# Convert audio to base64 and send back to client
|
# Convert audio to base64 and send back to client
|
||||||
audio_base64 = await encode_audio_data(audio_tensor)
|
audio_base64 = encode_audio_data(audio_tensor)
|
||||||
await websocket.send_json({
|
emit('audio_response', {
|
||||||
"type": "audio_response",
|
'type': 'audio_response',
|
||||||
"audio": audio_base64
|
'audio': audio_base64
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error generating audio: {str(e)}")
|
print(f"Error generating audio: {str(e)}")
|
||||||
await websocket.send_json({
|
emit('error', {
|
||||||
"type": "error",
|
'type': 'error',
|
||||||
"message": f"Error generating audio: {str(e)}"
|
'message': f"Error generating audio: {str(e)}"
|
||||||
})
|
})
|
||||||
|
|
||||||
elif action == "add_to_context":
|
@socketio.on('add_to_context')
|
||||||
|
def handle_add_to_context(data):
|
||||||
|
client_id = request.sid
|
||||||
|
if client_id not in active_clients:
|
||||||
|
emit('error', {'message': 'Client not registered'})
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
text = request.get("text", "")
|
text = data.get('text', '')
|
||||||
speaker_id = request.get("speaker", 0)
|
speaker_id = data.get('speaker', 0)
|
||||||
audio_data = request.get("audio", "")
|
audio_data = data.get('audio', '')
|
||||||
|
|
||||||
# Convert received audio to tensor
|
# Convert received audio to tensor
|
||||||
audio_tensor = await decode_audio_data(audio_data)
|
audio_tensor = decode_audio_data(audio_data)
|
||||||
|
|
||||||
# Add to conversation context
|
# Add to conversation context
|
||||||
context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))
|
active_clients[client_id]['context_segments'].append(
|
||||||
|
Segment(text=text, speaker=speaker_id, audio=audio_tensor)
|
||||||
|
)
|
||||||
|
|
||||||
await websocket.send_json({
|
emit('context_updated', {
|
||||||
"type": "context_updated",
|
'type': 'context_updated',
|
||||||
"message": "Audio added to context"
|
'message': 'Audio added to context'
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error adding to context: {str(e)}")
|
print(f"Error adding to context: {str(e)}")
|
||||||
await websocket.send_json({
|
emit('error', {
|
||||||
"type": "error",
|
'type': 'error',
|
||||||
"message": f"Error processing audio: {str(e)}"
|
'message': f"Error processing audio: {str(e)}"
|
||||||
})
|
})
|
||||||
|
|
||||||
elif action == "clear_context":
|
@socketio.on('clear_context')
|
||||||
context_segments = []
|
def handle_clear_context():
|
||||||
await websocket.send_json({
|
client_id = request.sid
|
||||||
"type": "context_updated",
|
if client_id in active_clients:
|
||||||
"message": "Context cleared"
|
active_clients[client_id]['context_segments'] = []
|
||||||
|
|
||||||
|
emit('context_updated', {
|
||||||
|
'type': 'context_updated',
|
||||||
|
'message': 'Context cleared'
|
||||||
})
|
})
|
||||||
|
|
||||||
elif action == "stream_audio":
|
@socketio.on('stream_audio')
|
||||||
|
def handle_stream_audio(data):
|
||||||
|
client_id = request.sid
|
||||||
|
if client_id not in active_clients:
|
||||||
|
emit('error', {'message': 'Client not registered'})
|
||||||
|
return
|
||||||
|
|
||||||
|
client = active_clients[client_id]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
speaker_id = request.get("speaker", 0)
|
speaker_id = data.get('speaker', 0)
|
||||||
audio_data = request.get("audio", "")
|
audio_data = data.get('audio', '')
|
||||||
|
|
||||||
# Convert received audio to tensor
|
# Convert received audio to tensor
|
||||||
audio_chunk = await decode_audio_data(audio_data)
|
audio_chunk = decode_audio_data(audio_data)
|
||||||
|
|
||||||
# Start streaming mode if not already started
|
# Start streaming mode if not already started
|
||||||
if not is_streaming:
|
if not client['is_streaming']:
|
||||||
is_streaming = True
|
client['is_streaming'] = True
|
||||||
streaming_buffer = []
|
client['streaming_buffer'] = []
|
||||||
energy_window.clear()
|
client['energy_window'].clear()
|
||||||
is_silence = False
|
client['is_silence'] = False
|
||||||
last_active_time = time.time()
|
client['last_active_time'] = time.time()
|
||||||
await websocket.send_json({
|
print(f"[{client_id}] Streaming started with speaker ID: {speaker_id}")
|
||||||
"type": "streaming_status",
|
emit('streaming_status', {
|
||||||
"status": "started"
|
'type': 'streaming_status',
|
||||||
|
'status': 'started'
|
||||||
})
|
})
|
||||||
|
|
||||||
# Calculate audio energy for silence detection
|
# Calculate audio energy for silence detection
|
||||||
chunk_energy = torch.mean(torch.abs(audio_chunk)).item()
|
chunk_energy = torch.mean(torch.abs(audio_chunk)).item()
|
||||||
energy_window.append(chunk_energy)
|
client['energy_window'].append(chunk_energy)
|
||||||
avg_energy = sum(energy_window) / len(energy_window)
|
avg_energy = sum(client['energy_window']) / len(client['energy_window'])
|
||||||
|
|
||||||
# Check if audio is silent
|
# Check if audio is silent
|
||||||
current_silence = avg_energy < SILENCE_THRESHOLD
|
current_silence = avg_energy < SILENCE_THRESHOLD
|
||||||
|
|
||||||
# Track silence transition
|
# Track silence transition
|
||||||
if not is_silence and current_silence:
|
if not client['is_silence'] and current_silence:
|
||||||
# Transition to silence
|
# Transition to silence
|
||||||
is_silence = True
|
client['is_silence'] = True
|
||||||
last_active_time = time.time()
|
client['last_active_time'] = time.time()
|
||||||
elif is_silence and not current_silence:
|
elif client['is_silence'] and not current_silence:
|
||||||
# User started talking again
|
# User started talking again
|
||||||
is_silence = False
|
client['is_silence'] = False
|
||||||
|
|
||||||
# Add chunk to buffer regardless of silence state
|
# Add chunk to buffer regardless of silence state
|
||||||
streaming_buffer.append(audio_chunk)
|
client['streaming_buffer'].append(audio_chunk)
|
||||||
|
|
||||||
# Check if silence has persisted long enough to consider "stopped talking"
|
# Check if silence has persisted long enough to consider "stopped talking"
|
||||||
silence_elapsed = time.time() - last_active_time
|
silence_elapsed = time.time() - client['last_active_time']
|
||||||
|
|
||||||
if is_silence and silence_elapsed >= SILENCE_DURATION_SEC and len(streaming_buffer) > 0:
|
if client['is_silence'] and silence_elapsed >= SILENCE_DURATION_SEC and len(client['streaming_buffer']) > 0:
|
||||||
# User has stopped talking - process the collected audio
|
# User has stopped talking - process the collected audio
|
||||||
full_audio = torch.cat(streaming_buffer, dim=0)
|
print(f"[{client_id}] Processing audio after {silence_elapsed:.2f}s of silence")
|
||||||
|
|
||||||
|
full_audio = torch.cat(client['streaming_buffer'], dim=0)
|
||||||
|
|
||||||
# Process with WhisperX speech-to-text
|
# Process with WhisperX speech-to-text
|
||||||
transcribed_text = await transcribe_audio(full_audio)
|
print(f"[{client_id}] Starting transcription with WhisperX...")
|
||||||
|
transcribed_text = transcribe_audio(full_audio)
|
||||||
|
|
||||||
# Log the transcription
|
# Log the transcription
|
||||||
print(f"Transcribed text: '{transcribed_text}'")
|
print(f"[{client_id}] Transcribed text: '{transcribed_text}'")
|
||||||
|
|
||||||
# Add to conversation context
|
# Handle the transcription result
|
||||||
if transcribed_text:
|
if transcribed_text:
|
||||||
|
# Add user message to context
|
||||||
user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
|
user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
|
||||||
context_segments.append(user_segment)
|
client['context_segments'].append(user_segment)
|
||||||
|
|
||||||
# Generate a contextual response
|
|
||||||
response_text = await generate_response(transcribed_text, context_segments)
|
|
||||||
|
|
||||||
# Send the transcribed text to client
|
# Send the transcribed text to client
|
||||||
await websocket.send_json({
|
emit('transcription', {
|
||||||
"type": "transcription",
|
'type': 'transcription',
|
||||||
"text": transcribed_text
|
'text': transcribed_text
|
||||||
|
})
|
||||||
|
|
||||||
|
# Generate a contextual response
|
||||||
|
response_text = generate_response(transcribed_text, client['context_segments'])
|
||||||
|
print(f"[{client_id}] Generating audio response: '{response_text}'")
|
||||||
|
|
||||||
|
# Let the client know we're processing
|
||||||
|
emit('processing_status', {
|
||||||
|
'type': 'processing_status',
|
||||||
|
'status': 'generating_audio',
|
||||||
|
'message': 'Generating audio response...'
|
||||||
})
|
})
|
||||||
|
|
||||||
# Generate audio for the response
|
# Generate audio for the response
|
||||||
|
try:
|
||||||
|
# Use a different speaker than the user
|
||||||
|
ai_speaker_id = 1 if speaker_id == 0 else 0
|
||||||
|
|
||||||
|
# Start audio generation with streaming (chunk by chunk)
|
||||||
|
audio_chunks = []
|
||||||
|
|
||||||
|
# This version tries to stream the audio generation in smaller chunks
|
||||||
|
# Note: CSM model doesn't natively support incremental generation,
|
||||||
|
# so we're simulating it here for a more responsive UI experience
|
||||||
|
|
||||||
|
# Generate the full response
|
||||||
audio_tensor = generator.generate(
|
audio_tensor = generator.generate(
|
||||||
text=response_text,
|
text=response_text,
|
||||||
speaker=1 if speaker_id == 0 else 0, # Use opposite speaker
|
speaker=ai_speaker_id,
|
||||||
context=context_segments,
|
context=client['context_segments'],
|
||||||
max_audio_length_ms=10_000,
|
max_audio_length_ms=10_000,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add response to context
|
# Add response to context
|
||||||
ai_segment = Segment(
|
ai_segment = Segment(
|
||||||
text=response_text,
|
text=response_text,
|
||||||
speaker=1 if speaker_id == 0 else 0,
|
speaker=ai_speaker_id,
|
||||||
audio=audio_tensor
|
audio=audio_tensor
|
||||||
)
|
)
|
||||||
context_segments.append(ai_segment)
|
client['context_segments'].append(ai_segment)
|
||||||
|
|
||||||
# Convert audio to base64 and send back to client
|
# Convert audio to base64 and send back to client
|
||||||
audio_base64 = await encode_audio_data(audio_tensor)
|
audio_base64 = encode_audio_data(audio_tensor)
|
||||||
await websocket.send_json({
|
emit('audio_response', {
|
||||||
"type": "audio_response",
|
'type': 'audio_response',
|
||||||
"text": response_text,
|
'text': response_text,
|
||||||
"audio": audio_base64
|
'audio': audio_base64
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f"[{client_id}] Audio response sent: {len(audio_base64)} bytes")
|
||||||
|
|
||||||
|
except Exception as gen_error:
|
||||||
|
print(f"Error generating audio response: {str(gen_error)}")
|
||||||
|
emit('error', {
|
||||||
|
'type': 'error',
|
||||||
|
'message': "Sorry, there was an error generating the audio response."
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
# If transcription failed, send a generic response
|
# If transcription failed, send a generic response
|
||||||
await websocket.send_json({
|
emit('error', {
|
||||||
"type": "error",
|
'type': 'error',
|
||||||
"message": "Sorry, I couldn't understand what you said. Could you try again?"
|
'message': "Sorry, I couldn't understand what you said. Could you try again?"
|
||||||
})
|
})
|
||||||
|
|
||||||
# Clear buffer and reset silence detection
|
# Clear buffer and reset silence detection
|
||||||
streaming_buffer = []
|
client['streaming_buffer'] = []
|
||||||
energy_window.clear()
|
client['energy_window'].clear()
|
||||||
is_silence = False
|
client['is_silence'] = False
|
||||||
last_active_time = time.time()
|
client['last_active_time'] = time.time()
|
||||||
|
|
||||||
# If buffer gets too large without silence, process it anyway
|
# If buffer gets too large without silence, process it anyway
|
||||||
# This prevents memory issues with very long streams
|
elif len(client['streaming_buffer']) >= 30: # ~6 seconds of audio at 5 chunks/sec
|
||||||
elif len(streaming_buffer) >= 30: # ~6 seconds of audio at 5 chunks/sec
|
print(f"[{client_id}] Processing long audio segment without silence")
|
||||||
print("Buffer limit reached, processing audio")
|
full_audio = torch.cat(client['streaming_buffer'], dim=0)
|
||||||
full_audio = torch.cat(streaming_buffer, dim=0)
|
|
||||||
|
|
||||||
# Process with WhisperX speech-to-text
|
# Process with WhisperX speech-to-text
|
||||||
transcribed_text = await transcribe_audio(full_audio)
|
transcribed_text = transcribe_audio(full_audio)
|
||||||
|
|
||||||
if transcribed_text:
|
if transcribed_text:
|
||||||
context_segments.append(Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio))
|
client['context_segments'].append(
|
||||||
|
Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
|
||||||
|
)
|
||||||
|
|
||||||
# Send the transcribed text to client
|
# Send the transcribed text to client
|
||||||
await websocket.send_json({
|
emit('transcription', {
|
||||||
"type": "transcription",
|
'type': 'transcription',
|
||||||
"text": transcribed_text + " (processing continued speech...)"
|
'text': transcribed_text + " (processing continued speech...)"
|
||||||
})
|
})
|
||||||
|
|
||||||
streaming_buffer = []
|
# Keep half of the buffer for context (sliding window approach)
|
||||||
|
half_point = len(client['streaming_buffer']) // 2
|
||||||
|
client['streaming_buffer'] = client['streaming_buffer'][half_point:]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
print(f"Error processing streaming audio: {str(e)}")
|
print(f"Error processing streaming audio: {str(e)}")
|
||||||
await websocket.send_json({
|
emit('error', {
|
||||||
"type": "error",
|
'type': 'error',
|
||||||
"message": f"Error processing streaming audio: {str(e)}"
|
'message': f"Error processing streaming audio: {str(e)}"
|
||||||
})
|
})
|
||||||
|
|
||||||
elif action == "stop_streaming":
|
@socketio.on('stop_streaming')
|
||||||
is_streaming = False
|
def handle_stop_streaming(data):
|
||||||
if streaming_buffer and len(streaming_buffer) > 5: # Only process if there's meaningful audio
|
client_id = request.sid
|
||||||
|
if client_id not in active_clients:
|
||||||
|
return
|
||||||
|
|
||||||
|
client = active_clients[client_id]
|
||||||
|
client['is_streaming'] = False
|
||||||
|
|
||||||
|
if client['streaming_buffer'] and len(client['streaming_buffer']) > 5:
|
||||||
# Process any remaining audio in the buffer
|
# Process any remaining audio in the buffer
|
||||||
full_audio = torch.cat(streaming_buffer, dim=0)
|
full_audio = torch.cat(client['streaming_buffer'], dim=0)
|
||||||
|
|
||||||
# Process with WhisperX speech-to-text
|
# Process with WhisperX speech-to-text
|
||||||
transcribed_text = await transcribe_audio(full_audio)
|
transcribed_text = transcribe_audio(full_audio)
|
||||||
|
|
||||||
if transcribed_text:
|
if transcribed_text:
|
||||||
context_segments.append(Segment(text=transcribed_text, speaker=request.get("speaker", 0), audio=full_audio))
|
client['context_segments'].append(
|
||||||
|
Segment(text=transcribed_text, speaker=data.get("speaker", 0), audio=full_audio)
|
||||||
|
)
|
||||||
|
|
||||||
# Send the transcribed text to client
|
# Send the transcribed text to client
|
||||||
await websocket.send_json({
|
emit('transcription', {
|
||||||
"type": "transcription",
|
'type': 'transcription',
|
||||||
"text": transcribed_text
|
'text': transcribed_text
|
||||||
})
|
})
|
||||||
|
|
||||||
streaming_buffer = []
|
client['streaming_buffer'] = []
|
||||||
await websocket.send_json({
|
emit('streaming_status', {
|
||||||
"type": "streaming_status",
|
'type': 'streaming_status',
|
||||||
"status": "stopped"
|
'status': 'stopped'
|
||||||
})
|
})
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size_ms=500):
|
||||||
manager.disconnect(websocket)
|
"""Stream audio to client in chunks to simulate real-time generation"""
|
||||||
print("Client disconnected")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {str(e)}")
|
|
||||||
try:
|
try:
|
||||||
await websocket.send_json({
|
if client_id not in active_clients:
|
||||||
"type": "error",
|
print(f"Client {client_id} not found for streaming")
|
||||||
"message": str(e)
|
return
|
||||||
})
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
manager.disconnect(websocket)
|
|
||||||
|
|
||||||
|
# Calculate chunk size in samples
|
||||||
|
chunk_size = int(generator.sample_rate * chunk_size_ms / 1000)
|
||||||
|
total_chunks = math.ceil(audio_tensor.size(0) / chunk_size)
|
||||||
|
|
||||||
|
print(f"Streaming audio in {total_chunks} chunks of {chunk_size_ms}ms each")
|
||||||
|
|
||||||
|
# Send initial response with text but no audio yet
|
||||||
|
socketio.emit('audio_response_start', {
|
||||||
|
'type': 'audio_response_start',
|
||||||
|
'text': text,
|
||||||
|
'total_chunks': total_chunks
|
||||||
|
}, room=client_id)
|
||||||
|
|
||||||
|
# Stream each chunk
|
||||||
|
for i in range(total_chunks):
|
||||||
|
start_idx = i * chunk_size
|
||||||
|
end_idx = min(start_idx + chunk_size, audio_tensor.size(0))
|
||||||
|
|
||||||
|
# Extract chunk
|
||||||
|
chunk = audio_tensor[start_idx:end_idx]
|
||||||
|
|
||||||
|
# Encode chunk
|
||||||
|
chunk_base64 = encode_audio_data(chunk)
|
||||||
|
|
||||||
|
# Send chunk
|
||||||
|
socketio.emit('audio_response_chunk', {
|
||||||
|
'type': 'audio_response_chunk',
|
||||||
|
'chunk_index': i,
|
||||||
|
'total_chunks': total_chunks,
|
||||||
|
'audio': chunk_base64,
|
||||||
|
'is_last': i == total_chunks - 1
|
||||||
|
}, room=client_id)
|
||||||
|
|
||||||
|
# Brief pause between chunks to simulate streaming
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Send completion message
|
||||||
|
socketio.emit('audio_response_complete', {
|
||||||
|
'type': 'audio_response_complete',
|
||||||
|
'text': text
|
||||||
|
}, room=client_id)
|
||||||
|
|
||||||
|
print(f"Audio streaming complete: {total_chunks} chunks sent")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error streaming audio to client: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
print(f"\n{'='*60}")
|
||||||
|
print(f"🔊 Sesame AI Voice Chat Server (Flask Implementation)")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"📡 Server Information:")
|
||||||
|
print(f" - Local URL: http://localhost:5000")
|
||||||
|
print(f" - Network URL: http://<your-ip-address>:5000")
|
||||||
|
print(f" - WebSocket: ws://<your-ip-address>:5000/socket.io")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"💡 To make this server public:")
|
||||||
|
print(f" 1. Ensure port 5000 is open in your firewall")
|
||||||
|
print(f" 2. Set up port forwarding on your router to port 5000")
|
||||||
|
print(f" 3. Or use a service like ngrok with: ngrok http 5000")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"🌐 Device: {device.upper()}")
|
||||||
|
print(f"🧠 Models loaded: Sesame CSM + WhisperX ({asr_model.device})")
|
||||||
|
print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n")
|
||||||
|
|
||||||
|
socketio.run(app, host="0.0.0.0", port=5000, debug=False)
|
||||||
852
Backend/voice-chat.js
Normal file
852
Backend/voice-chat.js
Normal file
@@ -0,0 +1,852 @@
|
|||||||
|
/**
|
||||||
|
* Sesame AI Voice Chat Client
|
||||||
|
*
|
||||||
|
* A web client that connects to a Sesame AI voice chat server and enables
|
||||||
|
* real-time voice conversation with an AI assistant.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Configuration constants
|
||||||
|
const SERVER_URL = window.location.hostname === 'localhost' ?
|
||||||
|
'http://localhost:5000' : window.location.origin;
|
||||||
|
const ENERGY_WINDOW_SIZE = 15;
|
||||||
|
const CLIENT_SILENCE_DURATION_MS = 750;
|
||||||
|
|
||||||
|
// DOM elements
|
||||||
|
const elements = {
|
||||||
|
conversation: null,
|
||||||
|
streamButton: null,
|
||||||
|
clearButton: null,
|
||||||
|
thresholdSlider: null,
|
||||||
|
thresholdValue: null,
|
||||||
|
visualizerCanvas: null,
|
||||||
|
visualizerLabel: null,
|
||||||
|
volumeLevel: null,
|
||||||
|
statusDot: null,
|
||||||
|
statusText: null,
|
||||||
|
speakerSelection: null,
|
||||||
|
autoPlayResponses: null,
|
||||||
|
showVisualizer: null
|
||||||
|
};
|
||||||
|
|
||||||
|
// Application state
|
||||||
|
const state = {
|
||||||
|
socket: null,
|
||||||
|
audioContext: null,
|
||||||
|
analyser: null,
|
||||||
|
microphone: null,
|
||||||
|
streamProcessor: null,
|
||||||
|
isStreaming: false,
|
||||||
|
isSpeaking: false,
|
||||||
|
silenceThreshold: 0.01,
|
||||||
|
energyWindow: [],
|
||||||
|
silenceTimer: null,
|
||||||
|
volumeUpdateInterval: null,
|
||||||
|
visualizerAnimationFrame: null,
|
||||||
|
currentSpeaker: 0
|
||||||
|
};
|
||||||
|
|
||||||
|
// Visualizer variables
|
||||||
|
let canvasContext = null;
|
||||||
|
let visualizerBufferLength = 0;
|
||||||
|
let visualizerDataArray = null;
|
||||||
|
|
||||||
|
// Initialize the application
|
||||||
|
function initializeApp() {
|
||||||
|
// Initialize the UI elements
|
||||||
|
initializeUIElements();
|
||||||
|
|
||||||
|
// Initialize socket.io connection
|
||||||
|
setupSocketConnection();
|
||||||
|
|
||||||
|
// Setup event listeners
|
||||||
|
setupEventListeners();
|
||||||
|
|
||||||
|
// Initialize visualizer
|
||||||
|
setupVisualizer();
|
||||||
|
|
||||||
|
// Show welcome message
|
||||||
|
addSystemMessage('Welcome to Sesame AI Voice Chat! Click "Start Conversation" to begin.');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize UI elements
|
||||||
|
function initializeUIElements() {
|
||||||
|
// Store references to UI elements
|
||||||
|
elements.conversation = document.getElementById('conversation');
|
||||||
|
elements.streamButton = document.getElementById('streamButton');
|
||||||
|
elements.clearButton = document.getElementById('clearButton');
|
||||||
|
elements.thresholdSlider = document.getElementById('thresholdSlider');
|
||||||
|
elements.thresholdValue = document.getElementById('thresholdValue');
|
||||||
|
elements.visualizerCanvas = document.getElementById('audioVisualizer');
|
||||||
|
elements.visualizerLabel = document.getElementById('visualizerLabel');
|
||||||
|
elements.volumeLevel = document.getElementById('volumeLevel');
|
||||||
|
elements.statusDot = document.getElementById('statusDot');
|
||||||
|
elements.statusText = document.getElementById('statusText');
|
||||||
|
elements.speakerSelection = document.getElementById('speakerSelect'); // Changed to match HTML
|
||||||
|
elements.autoPlayResponses = document.getElementById('autoPlayResponses');
|
||||||
|
elements.showVisualizer = document.getElementById('showVisualizer');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup Socket.IO connection
|
||||||
|
function setupSocketConnection() {
|
||||||
|
state.socket = io(SERVER_URL);
|
||||||
|
|
||||||
|
// Connection events
|
||||||
|
state.socket.on('connect', () => {
|
||||||
|
console.log('Connected to server');
|
||||||
|
updateConnectionStatus(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
state.socket.on('disconnect', () => {
|
||||||
|
console.log('Disconnected from server');
|
||||||
|
updateConnectionStatus(false);
|
||||||
|
|
||||||
|
// Stop streaming if active
|
||||||
|
if (state.isStreaming) {
|
||||||
|
stopStreaming(false);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
state.socket.on('error', (data) => {
|
||||||
|
console.error('Socket error:', data.message);
|
||||||
|
addSystemMessage(`Error: ${data.message}`);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Register message handlers
|
||||||
|
state.socket.on('audio_response', handleAudioResponse);
|
||||||
|
state.socket.on('transcription', handleTranscription);
|
||||||
|
state.socket.on('context_updated', handleContextUpdate);
|
||||||
|
state.socket.on('streaming_status', handleStreamingStatus);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup event listeners
|
||||||
|
function setupEventListeners() {
|
||||||
|
// Stream button
|
||||||
|
elements.streamButton.addEventListener('click', toggleStreaming);
|
||||||
|
|
||||||
|
// Clear button
|
||||||
|
elements.clearButton.addEventListener('click', clearConversation);
|
||||||
|
|
||||||
|
// Threshold slider
|
||||||
|
elements.thresholdSlider.addEventListener('input', updateThreshold);
|
||||||
|
|
||||||
|
// Speaker selection
|
||||||
|
elements.speakerSelection.addEventListener('change', () => {
|
||||||
|
state.currentSpeaker = parseInt(elements.speakerSelection.value, 10);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Visualizer toggle
|
||||||
|
elements.showVisualizer.addEventListener('change', toggleVisualizerVisibility);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup audio visualizer
|
||||||
|
function setupVisualizer() {
|
||||||
|
if (!elements.visualizerCanvas) return;
|
||||||
|
|
||||||
|
canvasContext = elements.visualizerCanvas.getContext('2d');
|
||||||
|
|
||||||
|
// Set canvas dimensions
|
||||||
|
elements.visualizerCanvas.width = elements.visualizerCanvas.offsetWidth;
|
||||||
|
elements.visualizerCanvas.height = elements.visualizerCanvas.offsetHeight;
|
||||||
|
|
||||||
|
// Initialize the visualizer
|
||||||
|
drawVisualizer();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update connection status UI
|
||||||
|
function updateConnectionStatus(isConnected) {
|
||||||
|
elements.statusDot.classList.toggle('active', isConnected);
|
||||||
|
elements.statusText.textContent = isConnected ? 'Connected' : 'Disconnected';
|
||||||
|
}
|
||||||
|
|
||||||
|
// Toggle streaming state
|
||||||
|
function toggleStreaming() {
|
||||||
|
if (state.isStreaming) {
|
||||||
|
stopStreaming(true);
|
||||||
|
} else {
|
||||||
|
startStreaming();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start streaming audio to the server
|
||||||
|
function startStreaming() {
|
||||||
|
if (state.isStreaming) return;
|
||||||
|
|
||||||
|
// Request microphone access
|
||||||
|
navigator.mediaDevices.getUserMedia({ audio: true, video: false })
|
||||||
|
.then(stream => {
|
||||||
|
// Show processing state while setting up
|
||||||
|
elements.streamButton.innerHTML = '<i class="fas fa-spinner fa-spin"></i> Initializing...';
|
||||||
|
|
||||||
|
// Create audio context
|
||||||
|
state.audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||||
|
|
||||||
|
// Create microphone source
|
||||||
|
state.microphone = state.audioContext.createMediaStreamSource(stream);
|
||||||
|
|
||||||
|
// Create analyser for visualizer
|
||||||
|
state.analyser = state.audioContext.createAnalyser();
|
||||||
|
state.analyser.fftSize = 256;
|
||||||
|
visualizerBufferLength = state.analyser.frequencyBinCount;
|
||||||
|
visualizerDataArray = new Uint8Array(visualizerBufferLength);
|
||||||
|
|
||||||
|
// Connect microphone to analyser
|
||||||
|
state.microphone.connect(state.analyser);
|
||||||
|
|
||||||
|
// Create script processor for audio processing
|
||||||
|
const bufferSize = 4096;
|
||||||
|
state.streamProcessor = state.audioContext.createScriptProcessor(bufferSize, 1, 1);
|
||||||
|
|
||||||
|
// Set up audio processing callback
|
||||||
|
state.streamProcessor.onaudioprocess = handleAudioProcess;
|
||||||
|
|
||||||
|
// Connect the processors
|
||||||
|
state.analyser.connect(state.streamProcessor);
|
||||||
|
state.streamProcessor.connect(state.audioContext.destination);
|
||||||
|
|
||||||
|
// Update UI
|
||||||
|
state.isStreaming = true;
|
||||||
|
elements.streamButton.innerHTML = '<i class="fas fa-microphone"></i> Listening...';
|
||||||
|
elements.streamButton.classList.add('recording');
|
||||||
|
|
||||||
|
// Initialize energy window
|
||||||
|
state.energyWindow = [];
|
||||||
|
|
||||||
|
// Start volume meter updates
|
||||||
|
state.volumeUpdateInterval = setInterval(updateVolumeMeter, 100);
|
||||||
|
|
||||||
|
// Start visualizer if enabled
|
||||||
|
if (elements.showVisualizer.checked && !state.visualizerAnimationFrame) {
|
||||||
|
drawVisualizer();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show starting message
|
||||||
|
addSystemMessage('Listening... Speak clearly into your microphone.');
|
||||||
|
|
||||||
|
// Notify the server that we're starting
|
||||||
|
state.socket.emit('stream_audio', {
|
||||||
|
audio: '',
|
||||||
|
speaker: state.currentSpeaker
|
||||||
|
});
|
||||||
|
})
|
||||||
|
.catch(err => {
|
||||||
|
console.error('Error accessing microphone:', err);
|
||||||
|
addSystemMessage(`Error: ${err.message}. Please make sure your microphone is connected and you've granted permission.`);
|
||||||
|
elements.streamButton.innerHTML = '<i class="fas fa-microphone"></i> Start Conversation';
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop streaming audio
|
||||||
|
function stopStreaming(notifyServer = true) {
|
||||||
|
if (!state.isStreaming) return;
|
||||||
|
|
||||||
|
// Update UI first
|
||||||
|
elements.streamButton.innerHTML = '<i class="fas fa-microphone"></i> Start Conversation';
|
||||||
|
elements.streamButton.classList.remove('recording');
|
||||||
|
elements.streamButton.classList.remove('processing');
|
||||||
|
|
||||||
|
// Stop volume meter updates
|
||||||
|
if (state.volumeUpdateInterval) {
|
||||||
|
clearInterval(state.volumeUpdateInterval);
|
||||||
|
state.volumeUpdateInterval = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop all audio processing
|
||||||
|
if (state.streamProcessor) {
|
||||||
|
state.streamProcessor.disconnect();
|
||||||
|
state.streamProcessor = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.analyser) {
|
||||||
|
state.analyser.disconnect();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.microphone) {
|
||||||
|
state.microphone.disconnect();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close audio context
|
||||||
|
if (state.audioContext && state.audioContext.state !== 'closed') {
|
||||||
|
state.audioContext.close().catch(err => console.warn('Error closing audio context:', err));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup animation frames
|
||||||
|
if (state.visualizerAnimationFrame) {
|
||||||
|
cancelAnimationFrame(state.visualizerAnimationFrame);
|
||||||
|
state.visualizerAnimationFrame = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset state
|
||||||
|
state.isStreaming = false;
|
||||||
|
state.isSpeaking = false;
|
||||||
|
|
||||||
|
// Notify the server
|
||||||
|
if (notifyServer && state.socket && state.socket.connected) {
|
||||||
|
state.socket.emit('stop_streaming', {
|
||||||
|
speaker: state.currentSpeaker
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show message
|
||||||
|
addSystemMessage('Conversation paused. Click "Start Conversation" to resume.');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle audio processing
|
||||||
|
function handleAudioProcess(event) {
|
||||||
|
const inputData = event.inputBuffer.getChannelData(0);
|
||||||
|
|
||||||
|
// Log audio buffer statistics
|
||||||
|
console.log(`Audio buffer: length=${inputData.length}, sample rate=${state.audioContext.sampleRate}Hz`);
|
||||||
|
|
||||||
|
// Calculate audio energy (volume level)
|
||||||
|
const energy = calculateAudioEnergy(inputData);
|
||||||
|
console.log(`Energy: ${energy.toFixed(6)}, threshold: ${state.silenceThreshold}`);
|
||||||
|
|
||||||
|
// Update energy window for averaging
|
||||||
|
updateEnergyWindow(energy);
|
||||||
|
|
||||||
|
// Calculate average energy
|
||||||
|
const avgEnergy = calculateAverageEnergy();
|
||||||
|
|
||||||
|
// Determine if audio is silent
|
||||||
|
const isSilent = avgEnergy < state.silenceThreshold;
|
||||||
|
console.log(`Silent: ${isSilent ? 'Yes' : 'No'}, avg energy: ${avgEnergy.toFixed(6)}`);
|
||||||
|
|
||||||
|
// Handle speech state based on silence
|
||||||
|
handleSpeechState(isSilent);
|
||||||
|
|
||||||
|
// Only send audio chunk if we detect speech
|
||||||
|
if (!isSilent) {
|
||||||
|
// Create a resampled version at 24kHz for the server
|
||||||
|
// Most WebRTC audio is 48kHz, but we want 24kHz for the model
|
||||||
|
const resampledData = downsampleBuffer(inputData, state.audioContext.sampleRate, 24000);
|
||||||
|
console.log(`Resampled audio: ${state.audioContext.sampleRate}Hz → 24000Hz, new length: ${resampledData.length}`);
|
||||||
|
|
||||||
|
// Send the audio chunk to the server
|
||||||
|
sendAudioChunk(resampledData, state.currentSpeaker);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup audio resources when done
|
||||||
|
function cleanupAudioResources() {
|
||||||
|
// Stop all audio processing
|
||||||
|
if (state.streamProcessor) {
|
||||||
|
state.streamProcessor.disconnect();
|
||||||
|
state.streamProcessor = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.analyser) {
|
||||||
|
state.analyser.disconnect();
|
||||||
|
state.analyser = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.microphone) {
|
||||||
|
state.microphone.disconnect();
|
||||||
|
state.microphone = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close audio context
|
||||||
|
if (state.audioContext && state.audioContext.state !== 'closed') {
|
||||||
|
state.audioContext.close().catch(err => console.warn('Error closing audio context:', err));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel all timers and animation frames
|
||||||
|
if (state.volumeUpdateInterval) {
|
||||||
|
clearInterval(state.volumeUpdateInterval);
|
||||||
|
state.volumeUpdateInterval = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.visualizerAnimationFrame) {
|
||||||
|
cancelAnimationFrame(state.visualizerAnimationFrame);
|
||||||
|
state.visualizerAnimationFrame = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.silenceTimer) {
|
||||||
|
clearTimeout(state.silenceTimer);
|
||||||
|
state.silenceTimer = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear conversation history
|
||||||
|
function clearConversation() {
|
||||||
|
if (elements.conversation) {
|
||||||
|
elements.conversation.innerHTML = '';
|
||||||
|
addSystemMessage('Conversation cleared.');
|
||||||
|
|
||||||
|
// Notify server to clear context
|
||||||
|
if (state.socket && state.socket.connected) {
|
||||||
|
state.socket.emit('clear_context');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate audio energy (volume)
|
||||||
|
function calculateAudioEnergy(buffer) {
|
||||||
|
let sum = 0;
|
||||||
|
for (let i = 0; i < buffer.length; i++) {
|
||||||
|
sum += buffer[i] * buffer[i];
|
||||||
|
}
|
||||||
|
return Math.sqrt(sum / buffer.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update energy window for averaging
|
||||||
|
function updateEnergyWindow(energy) {
|
||||||
|
state.energyWindow.push(energy);
|
||||||
|
if (state.energyWindow.length > ENERGY_WINDOW_SIZE) {
|
||||||
|
state.energyWindow.shift();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Calculate the average of the readings currently in `state.energyWindow`.
 *
 * @returns {number} Mean of the window, or 0 when the window is empty.
 */
function calculateAverageEnergy() {
    const readings = state.energyWindow;
    const count = readings.length;
    if (count === 0) {
        return 0;
    }

    let total = 0;
    readings.forEach((reading) => {
        total += reading;
    });
    return total / count;
}
|
||||||
|
|
||||||
|
/**
 * Update the silence threshold from the slider.
 *
 * Parses the slider's current value as a float, stores it in
 * `state.silenceThreshold` and shows it with three decimals in
 * `elements.thresholdValue`.
 */
function updateThreshold() {
    const threshold = parseFloat(elements.thresholdSlider.value);
    state.silenceThreshold = threshold;
    elements.thresholdValue.textContent = threshold.toFixed(3);
}
|
||||||
|
|
||||||
|
/**
 * Update the volume meter display from the averaged energy window.
 *
 * Returns without touching the meter while not streaming or while the
 * energy window is empty.
 */
function updateVolumeMeter() {
    const hasReadings = state.isStreaming && state.energyWindow.length;
    if (!hasReadings) {
        return;
    }

    // Energy values are typically very small (e.g., 0.001 to 0.1), so they
    // are multiplied up before being clamped to a 0-100 percentage.
    const SCALE_FACTOR = 1000;
    const scaledEnergy = calculateAverageEnergy() * SCALE_FACTOR;
    const percentage = Math.min(100, Math.max(0, scaledEnergy));

    // Update volume meter width
    const meterStyle = elements.volumeLevel.style;
    meterStyle.width = `${percentage}%`;

    // Change color based on level
    let color = '#4c84ff';
    if (percentage > 70) {
        color = '#ff5252';
    } else if (percentage > 30) {
        color = '#4CAF50';
    }
    meterStyle.backgroundColor = color;
}
|
||||||
|
|
||||||
|
/**
 * Handle speech/silence state transitions.
 *
 * Silence while speaking arms a timer (at most one at a time) that clears
 * `state.isSpeaking` after CLIENT_SILENCE_DURATION_MS. Non-silent audio
 * cancels a pending timer and marks the user as speaking.
 *
 * @param {boolean} isSilent - Whether the latest audio was classified as silent.
 */
function handleSpeechState(isSilent) {
    if (isSilent) {
        // Only a sustained silence counts as the end of speech; the delay
        // keeps brief pauses from being detected as the end.
        if (state.isSpeaking && !state.silenceTimer) {
            state.silenceTimer = setTimeout(() => {
                state.isSpeaking = false;
                state.silenceTimer = null;
            }, CLIENT_SILENCE_DURATION_MS);
        }
        return;
    }

    // User started speaking again, cancel the silence timer
    if (state.silenceTimer) {
        clearTimeout(state.silenceTimer);
        state.silenceTimer = null;
    }

    // Non-silent audio always means the user is speaking
    state.isSpeaking = true;
}
|
||||||
|
|
||||||
|
/**
 * Send an audio chunk to the server.
 *
 * The samples are wrapped in a WAV blob (written with a 24000 Hz header),
 * read back as a data URL and emitted on the socket as a `stream_audio`
 * event carrying `{ audio, speaker }`. If any sample is NaN or infinite the
 * whole chunk is replaced by silence of the same length. Nothing is sent
 * when the socket is not connected.
 *
 * @param {Float32Array} audioData - Samples to send (presumably already at
 *     24 kHz, since that rate is written into the header; confirm at caller).
 * @param {*} speaker - Speaker identifier, forwarded unchanged.
 */
function sendAudioChunk(audioData, speaker) {
    if (!state.socket || !state.socket.connected) {
        console.warn('Socket not connected');
        return;
    }

    console.log(`Preparing audio chunk: length=${audioData.length}, speaker=${speaker}`);

    // Check for NaN or invalid values (isFinite is false for NaN as well as
    // for +/-Infinity, so one test covers both)
    let hasInvalidValues = false;
    for (let i = 0; i < audioData.length; i++) {
        if (!isFinite(audioData[i])) {
            hasInvalidValues = true;
            console.warn(`Invalid audio value at index ${i}: ${audioData[i]}`);
            break;
        }
    }

    if (hasInvalidValues) {
        console.warn('Audio data contains invalid values. Creating silent audio.');
        // A new Float32Array is zero-filled, i.e. silence
        audioData = new Float32Array(audioData.length);
    }

    try {
        // Create WAV blob
        const wavData = createWavBlob(audioData, 24000);
        console.log(`WAV blob created: ${wavData.size} bytes`);

        const reader = new FileReader();

        reader.onloadend = function() {
            // loadend also fires after a failed read, in which case the
            // result is null and onerror has already reported the problem.
            if (reader.error || typeof reader.result !== 'string') {
                return;
            }

            // The read is asynchronous, so the socket may have been closed
            // since the check at the top of this function.
            if (!state.socket || !state.socket.connected) {
                console.warn('Socket not connected');
                return;
            }

            try {
                // Get base64 data
                const base64data = reader.result;
                console.log(`Base64 data created: ${base64data.length} bytes`);

                // Send to server
                state.socket.emit('stream_audio', {
                    audio: base64data,
                    speaker: speaker
                });
                console.log('Audio chunk sent to server');
            } catch (err) {
                console.error('Error preparing audio data:', err);
            }
        };

        reader.onerror = function() {
            console.error('Error reading audio data as base64');
        };

        reader.readAsDataURL(wavData);
    } catch (err) {
        console.error('Error creating WAV data:', err);
    }
}
|
||||||
|
|
||||||
|
/**
 * Create a WAV blob (mono, 16-bit PCM) from float audio samples.
 *
 * Missing or empty input is replaced by 1024 samples of silence. If encoding
 * throws, the error is logged and a blob of 1024 silent samples is returned
 * instead, so callers always receive a blob from the normal paths.
 *
 * @param {ArrayLike<number>} audioData - Samples; values are clamped to the
 *     -1..1 range before conversion.
 * @param {number} sampleRate - Sample rate written into the WAV header.
 * @returns {Blob} Blob of type `audio/wav`: a 44-byte header followed by
 *     two bytes per sample.
 */
function createWavBlob(audioData, sampleRate) {
    const HEADER_BYTES = 44;
    const BYTES_PER_SAMPLE = 2;
    const FALLBACK_SAMPLES = 1024;

    // Validate input
    if (!audioData || audioData.length === 0) {
        console.warn('Empty audio data provided to createWavBlob');
        audioData = new Float32Array(FALLBACK_SAMPLES); // zero-filled silence
    }

    // Convert float samples to little-endian 16-bit PCM starting at offset
    function floatTo16BitPCM(output, offset, input) {
        for (let i = 0; i < input.length; i++, offset += 2) {
            // Ensure values are in -1 to 1 range
            const s = Math.max(-1, Math.min(1, input[i]));
            // Negative and positive halves scale differently because the
            // 16-bit range is asymmetric (-32768..32767)
            output.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
        }
    }

    // Write an ASCII tag into the header, one byte per character
    function writeString(view, offset, string) {
        for (let i = 0; i < string.length; i++) {
            view.setUint8(offset + i, string.charCodeAt(i));
        }
    }

    // Build the complete WAV file (header + samples) for the given samples.
    // Shared by the normal path and the silent fallback so the header layout
    // exists in exactly one place.
    function encodeWav(samples) {
        const dataBytes = samples.length * BYTES_PER_SAMPLE;
        const view = new DataView(new ArrayBuffer(HEADER_BYTES + dataBytes));

        writeString(view, 0, 'RIFF');                             // RIFF identifier
        view.setUint32(4, 36 + dataBytes, true);                  // RIFF chunk size (file size - 8)
        writeString(view, 8, 'WAVE');                             // WAVE identifier
        writeString(view, 12, 'fmt ');                            // fmt chunk identifier
        view.setUint32(16, 16, true);                             // fmt chunk length
        view.setUint16(20, 1, true);                              // sample format (1 is PCM)
        view.setUint16(22, 1, true);                              // mono channel
        view.setUint32(24, sampleRate, true);                     // sample rate
        view.setUint32(28, sampleRate * BYTES_PER_SAMPLE, true);  // byte rate (sample rate * block align)
        view.setUint16(32, BYTES_PER_SAMPLE, true);               // block align (channels * bytes per sample)
        view.setUint16(34, 16, true);                             // bits per sample
        writeString(view, 36, 'data');                            // data chunk identifier
        view.setUint32(40, dataBytes, true);                      // data chunk length

        // Write the PCM samples
        floatTo16BitPCM(view, HEADER_BYTES, samples);

        return new Blob([view], { type: 'audio/wav' });
    }

    try {
        return encodeWav(audioData);
    } catch (err) {
        console.error('Error in createWavBlob:', err);

        // Create a minimal valid WAV file with silence as fallback
        return encodeWav(new Float32Array(FALLBACK_SAMPLES));
    }
}
|
||||||
|
|
||||||
|
/**
 * Draw one frame of the audio visualizer and schedule the next one.
 *
 * Returns immediately (without rescheduling) when there is no canvas
 * context. Otherwise the next animation frame is requested first, so the
 * loop keeps running even on frames that skip drawing because the
 * visualizer is hidden. While streaming with an analyser the bars show the
 * analyser's byte frequency data; otherwise the previous values decay by 5
 * per frame.
 */
function drawVisualizer() {
    if (!canvasContext) {
        return;
    }

    state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer);

    const canvas = elements.visualizerCanvas;

    // Keep the canvas opacity in sync with the checkbox and skip drawing
    // while the visualizer is hidden
    const visible = elements.showVisualizer.checked;
    const wantedOpacity = visible ? '1' : '0';
    if (canvas.style.opacity !== wantedOpacity) {
        canvas.style.opacity = wantedOpacity;
    }
    if (!visible) {
        return;
    }

    // Get frequency data if available
    if (state.isStreaming && state.analyser) {
        try {
            state.analyser.getByteFrequencyData(visualizerDataArray);
        } catch (e) {
            console.warn('Error getting frequency data:', e);
        }
    } else {
        // Fade out when not streaming
        visualizerDataArray.forEach((level, slot, levels) => {
            levels[slot] = Math.max(0, level - 5);
        });
    }

    const width = canvas.width;
    const height = canvas.height;

    // Clear canvas
    canvasContext.fillStyle = 'rgb(0, 0, 0)';
    canvasContext.fillRect(0, 0, width, height);

    // Draw gradient bars: at most 64, spread over the data array, with a
    // 1px gap between neighbours
    const barCount = Math.min(visualizerBufferLength, 64);
    const barWidth = width / barCount - 1;

    // Logarithmic scale makes low values more visible while still
    // maintaining the full range
    const LOG_FACTOR = 20;
    const logRange = Math.log(1 + LOG_FACTOR);

    for (let bar = 0; bar < barCount; bar++) {
        const level = visualizerDataArray[Math.floor(bar * visualizerBufferLength / barCount)];
        const intensity = level / 255;
        const barHeight = Math.log(1 + intensity * LOG_FACTOR) / logRange * height;

        // Position bars
        const left = bar * (barWidth + 1);
        const top = height - barHeight;

        // Color depends on bar position (hue) and amplitude
        // (saturation and lightness)
        const hue = bar / barCount * 360;
        const saturation = 80 + intensity * 20;
        const lightness = 40 + intensity * 20;

        // Draw main bar
        canvasContext.fillStyle = `hsl(${hue}, ${saturation}%, ${lightness}%)`;
        canvasContext.fillRect(left, top, barWidth, barHeight);

        // Decorations are only drawn on bars taller than 5px
        if (!(barHeight > 5)) {
            continue;
        }

        // Translucent gradient over the upper half of the bar
        const sheenHeight = barHeight * 0.5;
        const sheen = canvasContext.createLinearGradient(left, top, left, top + sheenHeight);
        sheen.addColorStop(0, `hsla(${hue}, ${saturation}%, ${lightness + 20}%, 0.4)`);
        sheen.addColorStop(1, `hsla(${hue}, ${saturation}%, ${lightness}%, 0)`);
        canvasContext.fillStyle = sheen;
        canvasContext.fillRect(left, top, barWidth, sheenHeight);

        // 2px highlight on top of the bar for a 3D effect
        canvasContext.fillStyle = `hsla(${hue}, ${saturation - 20}%, ${lightness + 30}%, 0.7)`;
        canvasContext.fillRect(left, top, barWidth, 2);
    }

    // Show the label only while not streaming
    elements.visualizerLabel.style.opacity = state.isStreaming ? '0' : '0.7';
}
|
||||||
|
|
||||||
|
/**
 * Toggle visualizer visibility to match the "show visualizer" checkbox.
 *
 * Sets the canvas opacity and, when the visualizer has just become visible
 * during streaming with no animation frame recorded in state, starts the
 * draw loop.
 */
function toggleVisualizerVisibility() {
    const visible = elements.showVisualizer.checked;
    elements.visualizerCanvas.style.opacity = visible ? '1' : '0';

    if (!visible) {
        return;
    }
    if (!state.isStreaming) {
        return;
    }
    if (state.visualizerAnimationFrame) {
        return;
    }
    drawVisualizer();
}
|
||||||
|
|
||||||
|
/**
 * Handle an audio response from the server.
 *
 * Appends an AI message to the conversation containing the optional text,
 * an audio player for `data.audio` and a timestamp, scrolls to the bottom,
 * auto-plays when that option is checked, and puts the stream button back
 * into its "Listening..." state while streaming.
 *
 * @param {{text?: string, audio: string}} data - Server payload; `audio` is
 *     used as the source URL of a WAV player (presumably a data URL; confirm
 *     against the server).
 */
function handleAudioResponse(data) {
    console.log('Received audio response');

    // Message container
    const bubble = document.createElement('div');
    bubble.className = 'message ai';

    // Text content, if available
    if (data.text) {
        const paragraph = document.createElement('p');
        paragraph.textContent = data.text;
        bubble.appendChild(paragraph);
    }

    // Audio player; the text content is shown by browsers that do not
    // support the audio element
    const player = Object.assign(document.createElement('audio'), {
        controls: true,
        className: 'audio-player',
        textContent: 'Your browser does not support the audio element.'
    });
    const source = Object.assign(document.createElement('source'), {
        src: data.audio,
        type: 'audio/wav'
    });
    player.appendChild(source);
    bubble.appendChild(player);

    // Timestamp
    const stamp = document.createElement('span');
    stamp.className = 'message-time';
    stamp.textContent = new Date().toLocaleTimeString();
    bubble.appendChild(stamp);

    // Add to conversation and auto-scroll to bottom
    const conversation = elements.conversation;
    conversation.appendChild(bubble);
    conversation.scrollTop = conversation.scrollHeight;

    // Auto-play if enabled
    if (elements.autoPlayResponses.checked) {
        player.play().catch((err) => {
            console.warn('Auto-play failed:', err);
            addSystemMessage('Auto-play failed. Please click play to hear the response.');
        });
    }

    // Re-enable stream button after processing is complete
    if (state.isStreaming) {
        const button = elements.streamButton;
        button.innerHTML = '<i class="fas fa-microphone"></i> Listening...';
        button.classList.add('recording');
        button.classList.remove('processing');
    }
}
|
||||||
|
|
||||||
|
/**
 * Handle a transcription response from the server.
 *
 * Appends a user message with the transcribed text and a timestamp to the
 * conversation and scrolls to the bottom.
 *
 * @param {{text: string}} data - Server payload carrying the transcription.
 */
function handleTranscription(data) {
    console.log('Received transcription:', data.text);

    // Message element
    const bubble = document.createElement('div');
    bubble.className = 'message user';

    // Text content
    const paragraph = document.createElement('p');
    paragraph.textContent = data.text;
    bubble.appendChild(paragraph);

    // Timestamp
    const stamp = document.createElement('span');
    stamp.className = 'message-time';
    stamp.textContent = new Date().toLocaleTimeString();
    bubble.appendChild(stamp);

    // Add to conversation and auto-scroll to bottom
    const conversation = elements.conversation;
    conversation.appendChild(bubble);
    conversation.scrollTop = conversation.scrollHeight;
}
|
||||||
|
|
||||||
|
/**
 * Handle a context update from the server by logging its message.
 *
 * @param {{message: string}} data - Server payload.
 */
function handleContextUpdate(data) {
    const { message } = data;
    console.log('Context updated:', message);
}
|
||||||
|
|
||||||
|
/**
 * Handle a streaming status update from the server.
 *
 * Logs the status. When the server reports 'stopped' while the client is
 * still streaming, streaming is stopped locally.
 *
 * @param {{status: string}} data - Server payload.
 */
function handleStreamingStatus(data) {
    console.log('Streaming status:', data.status);

    const stoppedByServer = data.status === 'stopped';
    if (stoppedByServer && state.isStreaming) {
        // Don't send to server since this came from server
        stopStreaming(false);
    }
}
|
||||||
|
|
||||||
|
/**
 * Add a system message to the conversation and scroll it into view.
 *
 * @param {string} message - Text to display; inserted as plain text.
 */
function addSystemMessage(message) {
    const entry = Object.assign(document.createElement('div'), {
        className: 'message system',
        textContent: message
    });

    const conversation = elements.conversation;
    conversation.appendChild(entry);

    // Auto-scroll to bottom
    conversation.scrollTop = conversation.scrollHeight;
}
|
||||||
|
|
||||||
|
/**
 * Resample an audio buffer to the target sample rate by picking the nearest
 * source sample for every output sample (no filtering or interpolation).
 *
 * @param {ArrayLike<number>} buffer - Source samples.
 * @param {number} originalSampleRate - Sample rate of `buffer`.
 * @param {number} targetSampleRate - Desired sample rate.
 * @returns {ArrayLike<number>} The same `buffer` object when the rates are
 *     equal, otherwise a new Float32Array of
 *     round(buffer.length * targetSampleRate / originalSampleRate) samples.
 */
function downsampleBuffer(buffer, originalSampleRate, targetSampleRate) {
    if (originalSampleRate === targetSampleRate) {
        return buffer;
    }

    const ratio = originalSampleRate / targetSampleRate;
    const newLength = Math.round(buffer.length / ratio);
    const result = new Float32Array(newLength);
    const lastIndex = buffer.length - 1;

    for (let i = 0; i < newLength; i++) {
        // When the target rate is higher than the source rate, rounding can
        // point one past the end of the buffer; reading there yields
        // undefined, which a Float32Array stores as NaN. Clamp to the last
        // real sample instead.
        const pos = Math.min(lastIndex, Math.round(i * ratio));
        result[i] = buffer[pos];
    }

    return result;
}
|
||||||
|
|
||||||
|
// Initialize the application when DOM is fully loaded
|
||||||
|
// Start-up is deferred to the DOMContentLoaded event. initializeApp is not
// defined in this part of the file.
document.addEventListener('DOMContentLoaded', initializeApp);
|
||||||
|
|
||||||
Reference in New Issue
Block a user