Demo Update 22

2025-03-30 09:17:31 -04:00
parent 8695dd0297
commit e69d9c5da1
2 changed files with 496 additions and 463 deletions

View File

@@ -3,7 +3,7 @@
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Live Voice Assistant with CSM</title> <title>Real-Time Voice Assistant</title>
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script> <script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
<style> <style>
body { body {
@@ -89,33 +89,39 @@
transition: all 0.3s ease; transition: all 0.3s ease;
} }
#talkButton { #micButton {
background-color: #4CAF50; background-color: #4CAF50;
color: white; color: white;
width: 200px; width: 200px;
box-shadow: 0 4px 8px rgba(76, 175, 80, 0.3); box-shadow: 0 4px 8px rgba(76, 175, 80, 0.3);
} }
#talkButton:hover { #micButton:hover {
background-color: #45a049; background-color: #45a049;
transform: translateY(-2px); transform: translateY(-2px);
} }
#talkButton.recording { #micButton.listening {
background-color: #f44336; background-color: #4CAF50;
box-shadow: 0 0 0 rgba(76, 175, 80, 0.4);
animation: pulse 1.5s infinite;
}
#micButton.speaking {
background-color: #f44336;
box-shadow: 0 0 0 rgba(244, 67, 54, 0.4);
animation: pulse 1.5s infinite; animation: pulse 1.5s infinite;
box-shadow: 0 4px 8px rgba(244, 67, 54, 0.3);
} }
@keyframes pulse { @keyframes pulse {
0% { 0% {
transform: scale(1); box-shadow: 0 0 0 0 rgba(76, 175, 80, 0.4);
} }
50% { 70% {
transform: scale(1.05); box-shadow: 0 0 0 15px rgba(76, 175, 80, 0);
} }
100% { 100% {
transform: scale(1); box-shadow: 0 0 0 0 rgba(76, 175, 80, 0);
} }
} }
@@ -126,45 +132,24 @@
color: #657786; color: #657786;
} }
.hidden {
display: none;
}
.transcription-info {
font-size: 0.8em;
color: #888;
margin-top: 4px;
text-align: right;
}
.text-only-indicator {
font-size: 0.8em;
color: #e74c3c;
margin-top: 4px;
font-style: italic;
}
.status-message {
text-align: center;
padding: 8px;
margin: 10px 0;
background-color: #f8f9fa;
border-radius: 5px;
color: #666;
font-size: 0.9em;
}
/* Audio visualizer styles */
.visualizer-container { .visualizer-container {
width: 100%; width: 100%;
height: 120px; height: 100px;
margin: 15px 0; margin: 15px 0;
border-radius: 10px; border-radius: 10px;
overflow: hidden; overflow: hidden;
background-color: #000; background-color: #1a1a1a;
position: relative; position: relative;
} }
.visualizer-container.user {
border: 2px solid #4CAF50;
}
.visualizer-container.ai {
border: 2px solid #2196F3;
}
#visualizer { #visualizer {
width: 100%; width: 100%;
height: 100%; height: 100%;
@@ -176,122 +161,59 @@
top: 10px; top: 10px;
left: 10px; left: 10px;
color: white; color: white;
font-size: 0.8em; font-size: 0.9em;
background-color: rgba(0, 0, 0, 0.5); background-color: rgba(0, 0, 0, 0.5);
padding: 4px 8px; padding: 4px 8px;
border-radius: 4px; border-radius: 4px;
} }
/* Real-time transcription */ .speech-indicator {
.live-transcription {
position: absolute;
bottom: 10px;
left: 10px;
right: 10px;
color: white;
font-size: 0.9em;
background-color: rgba(0, 0, 0, 0.5);
padding: 8px;
border-radius: 4px;
text-align: center;
max-height: 60px;
overflow-y: auto;
font-style: italic;
}
/* Wave animation for active speaker */
.speaking-wave {
display: inline-block; display: inline-block;
margin-left: 5px; width: 10px;
height: 10px;
border-radius: 50%;
margin-right: 5px;
vertical-align: middle; vertical-align: middle;
} }
.speaking-wave span { .user-speaking {
display: inline-block;
width: 3px;
height: 12px;
margin: 0 1px;
background-color: currentColor;
border-radius: 1px;
animation: speakingWave 1s infinite ease-in-out;
}
.speaking-wave span:nth-child(2) {
animation-delay: 0.1s;
}
.speaking-wave span:nth-child(3) {
animation-delay: 0.2s;
}
.speaking-wave span:nth-child(4) {
animation-delay: 0.3s;
}
@keyframes speakingWave {
0%, 100% {
height: 4px;
}
50% {
height: 12px;
}
}
/* Modern switch for visualizer toggle */
.switch-container {
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 10px;
}
.switch {
position: relative;
display: inline-block;
width: 50px;
height: 24px;
margin-left: 10px;
}
.switch input {
opacity: 0;
width: 0;
height: 0;
}
.slider {
position: absolute;
cursor: pointer;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: #ccc;
transition: .4s;
border-radius: 24px;
}
.slider:before {
position: absolute;
content: "";
height: 16px;
width: 16px;
left: 4px;
bottom: 4px;
background-color: white;
transition: .4s;
border-radius: 50%;
}
input:checked + .slider {
background-color: #4CAF50; background-color: #4CAF50;
animation: blink 1s infinite;
} }
input:checked + .slider:before { .ai-speaking {
transform: translateX(26px); background-color: #2196F3;
animation: blink 1s infinite;
}
@keyframes blink {
0%, 100% { opacity: 1; }
50% { opacity: 0.4; }
}
.connection-status {
padding: 6px 10px;
border-radius: 4px;
font-size: 0.8em;
margin-top: 10px;
text-align: center;
}
.connection-status.connected {
background-color: #d4edda;
color: #155724;
}
.connection-status.connecting {
background-color: #fff3cd;
color: #856404;
}
.connection-status.disconnected {
background-color: #f8d7da;
color: #721c24;
} }
/* Toast notification for feedback */
.toast { .toast {
position: fixed; position: fixed;
bottom: 20px; bottom: 20px;
@@ -328,38 +250,27 @@
</style> </style>
</head> </head>
<body> <body>
<h1>Live Voice Assistant with CSM</h1> <h1>Real-Time Voice Assistant</h1>
<div id="conversation"></div> <div id="conversation"></div>
<div class="switch-container"> <div class="visualizer-container user" id="visualizerContainer">
<span>Audio Visualizer</span>
<label class="switch">
<input type="checkbox" id="visualizerToggle" checked>
<span class="slider"></span>
</label>
</div>
<div class="visualizer-container" id="visualizerContainer">
<canvas id="visualizer"></canvas> <canvas id="visualizer"></canvas>
<div class="visualizer-label" id="visualizerLabel">Listening...</div> <div class="visualizer-label" id="visualizerLabel">Listening...</div>
<div class="live-transcription" id="liveTranscription"></div>
</div> </div>
<div id="controls"> <div id="controls">
<button id="talkButton">Press to Talk</button> <button id="micButton">Press to Talk</button>
</div> </div>
<div id="status">Connecting to server...</div> <div id="status">Connecting to server...</div>
<script> <script>
const socket = io(); const socket = io();
const talkButton = document.getElementById('talkButton'); const micButton = document.getElementById('micButton');
const conversation = document.getElementById('conversation'); const conversation = document.getElementById('conversation');
const status = document.getElementById('status'); const status = document.getElementById('status');
const visualizerToggle = document.getElementById('visualizerToggle');
const visualizerContainer = document.getElementById('visualizerContainer'); const visualizerContainer = document.getElementById('visualizerContainer');
const visualizerLabel = document.getElementById('visualizerLabel'); const visualizerLabel = document.getElementById('visualizerLabel');
const liveTranscription = document.getElementById('liveTranscription');
const canvas = document.getElementById('visualizer'); const canvas = document.getElementById('visualizer');
const canvasCtx = canvas.getContext('2d'); const canvasCtx = canvas.getContext('2d');
@@ -386,19 +297,6 @@
canvas.height = visualizerContainer.offsetHeight; canvas.height = visualizerContainer.offsetHeight;
} }
// Handle visualizer toggle
visualizerToggle.addEventListener('change', function() {
visualizerActive = this.checked;
visualizerContainer.style.display = visualizerActive ? 'block' : 'none';
if (!visualizerActive && visualizerAnimationId) {
cancelAnimationFrame(visualizerAnimationId);
visualizerAnimationId = null;
} else if (visualizerActive && audioAnalyser) {
drawVisualizer();
}
});
// Connect to server // Connect to server
socket.on('connect', () => { socket.on('connect', () => {
status.textContent = 'Connected to server'; status.textContent = 'Connected to server';
@@ -491,8 +389,8 @@
setupScriptProcessor(stream); setupScriptProcessor(stream);
} }
// Setup talk button // Setup mic button
talkButton.addEventListener('click', toggleTalking); micButton.addEventListener('click', toggleTalking);
// Setup keyboard shortcuts // Setup keyboard shortcuts
document.addEventListener('keydown', (e) => { document.addEventListener('keydown', (e) => {
@@ -618,8 +516,8 @@
if (!sessionActive || isAITalking) return; if (!sessionActive || isAITalking) return;
isStreaming = true; isStreaming = true;
talkButton.classList.add('recording'); micButton.classList.add('listening');
talkButton.textContent = 'Release to Stop'; micButton.textContent = 'Release to Stop';
status.textContent = 'Listening...'; status.textContent = 'Listening...';
visualizerLabel.textContent = 'You are speaking...'; visualizerLabel.textContent = 'You are speaking...';
@@ -630,10 +528,6 @@
// Tell server we're starting to speak // Tell server we're starting to speak
socket.emit('start_speaking'); socket.emit('start_speaking');
// Clear previous transcriptions
liveTranscription.textContent = '';
liveTranscription.classList.remove('hidden');
} }
// Stop talking to the assistant // Stop talking to the assistant
@@ -641,15 +535,12 @@
if (!isStreaming) return; if (!isStreaming) return;
isStreaming = false; isStreaming = false;
talkButton.classList.remove('recording'); micButton.classList.remove('listening');
talkButton.textContent = 'Press to Talk'; micButton.textContent = 'Press to Talk';
status.textContent = 'Processing...'; status.textContent = 'Processing...';
// Tell server we're done speaking // Tell server we're done speaking
socket.emit('stop_speaking'); socket.emit('stop_speaking');
// Hide live transcription temporarily
liveTranscription.classList.add('hidden');
} }
// Send audio chunk to server // Send audio chunk to server
@@ -702,8 +593,7 @@
// Handle real-time transcription // Handle real-time transcription
socket.on('live_transcription', (data) => { socket.on('live_transcription', (data) => {
liveTranscription.textContent = data.text || '...'; visualizerLabel.textContent = data.text || '...';
liveTranscription.classList.remove('hidden');
}); });
// Handle final transcription // Handle final transcription
@@ -749,8 +639,8 @@
speakingWave.remove(); speakingWave.remove();
} }
// Re-enable talk button if it was disabled // Re-enable mic button if it was disabled
talkButton.disabled = false; micButton.disabled = false;
}); });
// Legacy handler for text-only responses // Legacy handler for text-only responses

View File

@@ -8,14 +8,14 @@ import numpy as np
from flask import Flask, render_template, request from flask import Flask, render_template, request
from flask_socketio import SocketIO, emit from flask_socketio import SocketIO, emit
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from collections import deque import threading
import queue
import requests import requests
import huggingface_hub import huggingface_hub
from generator import load_csm_1b, Segment from generator import load_csm_1b, Segment
import threading from collections import deque
import queue
import asyncio
import json import json
import webrtcvad # For voice activity detection
# Configure environment with longer timeouts # Configure environment with longer timeouts
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600" # 10 minutes timeout for downloads os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600" # 10 minutes timeout for downloads
@@ -28,7 +28,7 @@ app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key' app.config['SECRET_KEY'] = 'your-secret-key'
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet') socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
# Explicitly check for CUDA and print more detailed info # Explicitly check for CUDA and print detailed info
print("\n=== CUDA Information ===") print("\n=== CUDA Information ===")
if torch.cuda.is_available(): if torch.cuda.is_available():
print(f"CUDA is available") print(f"CUDA is available")
@@ -47,18 +47,9 @@ try:
except: except:
print("cuDNN is not available (libcudnn_ops_infer.so.8 not found)") print("cuDNN is not available (libcudnn_ops_infer.so.8 not found)")
# Check for other compute platforms # Determine compute device
if torch.backends.mps.is_available():
print("MPS (Apple Silicon) is available")
else:
print("MPS is not available")
print("========================\n")
# Check for CUDA availability and handle potential CUDA/cuDNN issues
try: try:
if torch.cuda.is_available(): if torch.cuda.is_available():
# Try to initialize CUDA to check if libraries are properly loaded
_ = torch.zeros(1).cuda()
device = "cuda" device = "cuda"
whisper_compute_type = "float16" whisper_compute_type = "float16"
print("🟢 CUDA is available and initialized successfully") print("🟢 CUDA is available and initialized successfully")
@@ -83,14 +74,42 @@ whisper_model = None
csm_generator = None csm_generator = None
llm_model = None llm_model = None
llm_tokenizer = None llm_tokenizer = None
vad = None
# Constants
SAMPLE_RATE = 16000 # For VAD
VAD_FRAME_SIZE = 480 # 30ms at 16kHz for VAD
VAD_MODE = 3 # Most aggressive filtering (valid modes are 0-3)
AUDIO_CHUNK_SIZE = 2400 # 100ms chunks when streaming AI voice
# Audio sample rates
CLIENT_SAMPLE_RATE = 44100 # Assumed browser capture rate (WebAudio typically runs at 44.1 or 48 kHz)
WHISPER_SAMPLE_RATE = 16000 # Whisper expects 16kHz
# Session data structures
user_sessions = {} # session_id -> complete session data
# WebRTC ICE servers (STUN/TURN servers for NAT traversal)
ICE_SERVERS = [
{"urls": "stun:stun.l.google.com:19302"},
{"urls": "stun:stun1.l.google.com:19302"}
]
def load_models(): def load_models():
global whisper_model, csm_generator, llm_model, llm_tokenizer """Load all necessary models"""
global whisper_model, csm_generator, llm_model, llm_tokenizer, vad
# Initialize Voice Activity Detector
try:
vad = webrtcvad.Vad(VAD_MODE)
print("Voice Activity Detector initialized")
except Exception as e:
print(f"Error initializing VAD: {e}")
vad = None
# Initialize Faster-Whisper for transcription # Initialize Faster-Whisper for transcription
try: try:
print("Loading Whisper model...") print("Loading Whisper model...")
# Import here to avoid immediate import errors if package is missing
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper") whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper")
print("Whisper model loaded successfully") print("Whisper model loaded successfully")
@@ -110,9 +129,8 @@ def load_models():
# Initialize Llama 3.2 model for response generation # Initialize Llama 3.2 model for response generation
try: try:
print("Loading Llama 3.2 model...") print("Loading Llama 3.2 model...")
llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources llm_model_id = "meta-llama/Llama-3.2-1B"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama") llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama")
# Use the right data type based on device
dtype = torch.bfloat16 if device != "cpu" else torch.float32 dtype = torch.bfloat16 if device != "cpu" else torch.float32
llm_model = AutoModelForCausalLM.from_pretrained( llm_model = AutoModelForCausalLM.from_pretrained(
llm_model_id, llm_model_id,
@@ -126,247 +144,339 @@ def load_models():
print(f"Error loading Llama 3.2 model: {e}") print(f"Error loading Llama 3.2 model: {e}")
print("Will use a fallback response generation method") print("Will use a fallback response generation method")
# Store conversation context
conversation_context = {} # session_id -> context
active_audio_streams = {} # session_id -> stream status
@app.route('/') @app.route('/')
def index(): def index():
"""Serve the main interface"""
return render_template('index.html') return render_template('index.html')
@socketio.on('connect') @socketio.on('connect')
def handle_connect(): def handle_connect():
print(f"Client connected: {request.sid}") """Handle new client connection"""
conversation_context[request.sid] = { session_id = request.sid
print(f"Client connected: {session_id}")
# Initialize session data
user_sessions[session_id] = {
# Conversation context
'segments': [], 'segments': [],
'speakers': [0, 1], # 0 = user, 1 = bot 'conversation_history': [],
'audio_buffer': deque(maxlen=10), # Store recent audio chunks 'is_turn_active': False,
'is_speaking': False,
'last_activity': time.time(), # Audio buffers and state
'active_session': True, 'vad_buffer': deque(maxlen=30), # ~1s of audio at 30fps
'transcription_buffer': [] # For real-time transcription 'audio_buffer': bytearray(),
'is_user_speaking': False,
'last_vad_active': time.time(),
'silence_duration': 0,
'speech_frames': 0,
# AI state
'is_ai_speaking': False,
'should_interrupt_ai': False,
'ai_stream_queue': queue.Queue(),
# WebRTC status
'webrtc_connected': False,
'webrtc_peer_id': None,
# Processing flags
'is_processing': False,
'pending_user_audio': None
} }
emit('ready', {
'message': 'Connection established', # Send config to client
'sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000 emit('session_ready', {
'whisper_available': whisper_model is not None,
'csm_available': csm_generator is not None,
'llm_available': llm_model is not None,
'client_sample_rate': CLIENT_SAMPLE_RATE,
'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000,
'ice_servers': ICE_SERVERS
}) })
@socketio.on('disconnect') @socketio.on('disconnect')
def handle_disconnect(): def handle_disconnect():
print(f"Client disconnected: {request.sid}") """Handle client disconnection"""
session_id = request.sid session_id = request.sid
print(f"Client disconnected: {session_id}")
# Clean up resources # Clean up resources
if session_id in conversation_context: if session_id in user_sessions:
conversation_context[session_id]['active_session'] = False # Signal any running threads to stop
del conversation_context[session_id] user_sessions[session_id]['should_interrupt_ai'] = True
if session_id in active_audio_streams: # Clean up resources
active_audio_streams[session_id]['active'] = False del user_sessions[session_id]
del active_audio_streams[session_id]
@socketio.on('webrtc_signal')
def handle_webrtc_signal(data):
"""Handle WebRTC signaling for P2P connection establishment"""
session_id = request.sid
if session_id not in user_sessions:
return
# Simply relay the signal to the client
# In a multi-user app, we would route this to the correct peer
emit('webrtc_signal', data)
@socketio.on('webrtc_connected')
def handle_webrtc_connected(data):
"""Client notifies that WebRTC connection is established"""
session_id = request.sid
if session_id not in user_sessions:
return
user_sessions[session_id]['webrtc_connected'] = True
print(f"WebRTC connected for session {session_id}")
emit('ready_for_speech', {'message': 'Ready to start conversation'})
@socketio.on('audio_stream') @socketio.on('audio_stream')
def handle_audio_stream(data): def handle_audio_stream(data):
"""Handle incoming audio stream from client""" """Process incoming audio stream packets from client"""
session_id = request.sid session_id = request.sid
if session_id not in user_sessions:
if session_id not in conversation_context:
return return
context = conversation_context[session_id] session = user_sessions[session_id]
context['last_activity'] = time.time()
# Process different stream events
if data.get('event') == 'start':
# Client is starting to send audio
context['is_speaking'] = True
context['audio_buffer'].clear()
context['transcription_buffer'] = []
print(f"User {session_id} started streaming audio")
# If AI was speaking, interrupt it
if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
active_audio_streams[session_id]['active'] = False
emit('ai_stream_interrupt', {}, room=session_id)
elif data.get('event') == 'data':
# Audio data received
if not context['is_speaking']:
return
# Decode audio chunk
try:
audio_data = base64.b64decode(data.get('audio', ''))
if not audio_data:
return
audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
# Apply a simple noise gate
if np.mean(np.abs(audio_numpy)) < 0.01: # Very quiet
return
audio_tensor = torch.tensor(audio_numpy)
# Add to audio buffer
context['audio_buffer'].append(audio_tensor)
# Real-time transcription (periodic)
if len(context['audio_buffer']) % 3 == 0: # Process every 3 chunks
threading.Thread(
target=process_realtime_transcription,
args=(session_id,),
daemon=True
).start()
except Exception as e:
print(f"Error processing audio chunk: {e}")
elif data.get('event') == 'end':
# Client has finished sending audio
context['is_speaking'] = False
if len(context['audio_buffer']) > 0:
# Process the complete utterance
threading.Thread(
target=process_complete_utterance,
args=(session_id,),
daemon=True
).start()
print(f"User {session_id} stopped streaming audio")
def process_realtime_transcription(session_id):
"""Process incoming audio for real-time transcription"""
if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
return
context = conversation_context[session_id]
if not context['audio_buffer'] or not context['is_speaking']:
return
try: try:
# Combine current buffer for transcription # Decode audio data
buffer_copy = list(context['audio_buffer']) audio_bytes = base64.b64decode(data.get('audio', ''))
if not buffer_copy: if not audio_bytes or len(audio_bytes) < 2: # Need at least one sample
return return
full_audio = torch.cat(buffer_copy, dim=0) # Add to current audio buffer
session['audio_buffer'] += audio_bytes
# Save audio to temporary WAV file for transcription # Check for speech using VAD
temp_audio_path = f"temp_rt_{session_id}.wav" has_speech = detect_speech(audio_bytes, session_id)
# Handle speech state machine
if has_speech:
# Reset silence tracking when speech is detected
session['last_vad_active'] = time.time()
session['silence_duration'] = 0
session['speech_frames'] += 1
# If not already marked as speaking and we have enough speech frames
if not session['is_user_speaking'] and session['speech_frames'] >= 5:
on_speech_started(session_id)
else:
# No speech detected in this frame
if session['is_user_speaking']:
# Calculate silence duration
now = time.time()
session['silence_duration'] = now - session['last_vad_active']
# If silent for more than 0.8 seconds, end the speech segment
if session['silence_duration'] > 0.8 and session['speech_frames'] > 8:
on_speech_ended(session_id)
else:
# Not speaking and no speech, just a silent frame
session['speech_frames'] = max(0, session['speech_frames'] - 1)
except Exception as e:
print(f"Error processing audio stream: {e}")
def detect_speech(audio_bytes, session_id):
"""Use VAD to check if audio contains speech"""
if session_id not in user_sessions:
return False
session = user_sessions[session_id]
# Store in VAD buffer for history
session['vad_buffer'].append(audio_bytes)
if vad is None:
# Fallback to simple energy detection
audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
energy = np.mean(np.abs(audio_data)) / 32768.0
return energy > 0.015 # Simple threshold
try:
# Ensure we have the right amount of data for VAD
audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
# If we have too much data, use just the right amount
if len(audio_data) >= VAD_FRAME_SIZE:
frame = audio_data[:VAD_FRAME_SIZE].tobytes()
return vad.is_speech(frame, SAMPLE_RATE)
# If too little data, accumulate in the VAD buffer and check periodically
if len(session['vad_buffer']) >= 3:
# Combine recent chunks to get enough data
combined = bytearray()
for chunk in list(session['vad_buffer'])[-3:]:
combined.extend(chunk)
# Extract one full VAD frame (int16 samples are 2 bytes each)
if len(combined) >= VAD_FRAME_SIZE * 2:
frame = combined[:VAD_FRAME_SIZE * 2]
return vad.is_speech(bytes(frame), SAMPLE_RATE)
return False
except Exception as e:
print(f"VAD error: {e}")
return False
def on_speech_started(session_id):
"""Handle start of user speech"""
if session_id not in user_sessions:
return
session = user_sessions[session_id]
# Reset audio buffer
session['audio_buffer'] = bytearray()
session['is_user_speaking'] = True
session['is_turn_active'] = True
# If AI is speaking, we need to interrupt it
if session['is_ai_speaking']:
session['should_interrupt_ai'] = True
emit('ai_interrupted_by_user', room=session_id)
# Notify client that we detected speech
emit('user_speech_start', room=session_id)
def on_speech_ended(session_id):
"""Handle end of user speech segment"""
if session_id not in user_sessions:
return
session = user_sessions[session_id]
# Mark as not speaking anymore
session['is_user_speaking'] = False
session['speech_frames'] = 0
# If no audio or already processing, skip
if len(session['audio_buffer']) < 4000 or session['is_processing']: # Require a minimum amount of buffered audio
session['audio_buffer'] = bytearray()
return
# Mark as processing to prevent multiple processes
session['is_processing'] = True
# Create a copy of the audio buffer
audio_copy = session['audio_buffer']
session['audio_buffer'] = bytearray()
# Convert audio to the format needed for processing
try:
# Convert to float32 between -1 and 1
audio_np = np.frombuffer(audio_copy, dtype=np.int16).astype(np.float32) / 32768.0
audio_tensor = torch.from_numpy(audio_np)
# Resample to Whisper's expected sample rate if necessary
if CLIENT_SAMPLE_RATE != WHISPER_SAMPLE_RATE:
audio_tensor = torchaudio.functional.resample(
audio_tensor,
orig_freq=CLIENT_SAMPLE_RATE,
new_freq=WHISPER_SAMPLE_RATE
)
# Save as WAV for transcription
temp_audio_path = f"temp_audio_{session_id}.wav"
torchaudio.save( torchaudio.save(
temp_audio_path, temp_audio_path,
full_audio.unsqueeze(0), audio_tensor.unsqueeze(0),
44100 # Assuming 44.1kHz from client WHISPER_SAMPLE_RATE
) )
# Transcribe with Whisper if available # Start transcription and response process in a thread
if whisper_model is not None: threading.Thread(
segments, _ = whisper_model.transcribe(temp_audio_path, beam_size=5) target=process_user_utterance,
text = " ".join([segment.text for segment in segments]) args=(session_id, temp_audio_path, audio_tensor),
daemon=True
).start()
# Notify client that processing has started
emit('processing_speech', room=session_id)
if text.strip():
context['transcription_buffer'].append(text)
# Send partial transcription to client
emit('partial_transcription', {'text': text}, room=session_id)
except Exception as e: except Exception as e:
print(f"Error in realtime transcription: {e}") print(f"Error preparing audio: {e}")
finally: session['is_processing'] = False
# Clean up emit('error', {'message': f'Error processing audio: {str(e)}'}, room=session_id)
if os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
def process_complete_utterance(session_id): def process_user_utterance(session_id, audio_path, audio_tensor):
"""Process completed user utterance, generate response and stream audio back""" """Process user utterance, transcribe and generate response"""
if session_id not in conversation_context or not conversation_context[session_id]['active_session']: if session_id not in user_sessions:
return return
context = conversation_context[session_id] session = user_sessions[session_id]
if not context['audio_buffer']:
return
# Combine audio chunks
full_audio = torch.cat(list(context['audio_buffer']), dim=0)
context['audio_buffer'].clear()
# Save audio to temporary WAV file for transcription
temp_audio_path = f"temp_audio_{session_id}.wav"
torchaudio.save(
temp_audio_path,
full_audio.unsqueeze(0),
44100 # Assuming 44.1kHz from client
)
try: try:
# Try using Whisper first if available # Transcribe audio
if whisper_model is not None: if whisper_model is not None:
user_text = transcribe_with_whisper(temp_audio_path) user_text = transcribe_with_whisper(audio_path)
else: else:
# Fallback to Google's speech recognition # Fallback to another transcription service
user_text = transcribe_with_google(temp_audio_path) user_text = transcribe_fallback(audio_path)
if not user_text: # Clean up temp file
print("No speech detected.") if os.path.exists(audio_path):
emit('error', {'message': 'No speech detected. Please try again.'}, room=session_id) os.remove(audio_path)
# Check if we got meaningful text
if not user_text or len(user_text.strip()) < 2:
emit('no_speech_detected', room=session_id)
session['is_processing'] = False
return return
print(f"Transcribed: {user_text}") print(f"Transcribed: {user_text}")
# Add to conversation segments # Create user segment
user_segment = Segment( user_segment = Segment(
text=user_text, text=user_text,
speaker=0, # User is speaker 0 speaker=0, # User is speaker 0
audio=full_audio audio=audio_tensor
) )
context['segments'].append(user_segment) session['segments'].append(user_segment)
# Generate bot response text # Update conversation history
bot_response = generate_llm_response(user_text, context['segments']) session['conversation_history'].append({
print(f"Bot response: {bot_response}") 'role': 'user',
'text': user_text
})
# Send transcribed text to client # Send transcription to client
emit('transcription', {'text': user_text}, room=session_id) emit('transcription', {'text': user_text}, room=session_id)
# Generate and stream audio response if CSM is available # Generate AI response
ai_response = generate_ai_response(user_text, session_id)
# Send text response to client
emit('ai_response_text', {'text': ai_response}, room=session_id)
# Update conversation history
session['conversation_history'].append({
'role': 'assistant',
'text': ai_response
})
# Generate voice response if CSM is available
if csm_generator is not None: if csm_generator is not None:
# Create stream state object session['is_ai_speaking'] = True
active_audio_streams[session_id] = { session['should_interrupt_ai'] = False
'active': True,
'text': bot_response
}
# Send initial response to prepare client # Begin streaming audio response
emit('ai_stream_start', {
'text': bot_response
}, room=session_id)
# Start audio generation in a separate thread
threading.Thread( threading.Thread(
target=generate_and_stream_audio_realtime, target=stream_ai_response,
args=(bot_response, context['segments'], session_id), args=(ai_response, session_id),
daemon=True daemon=True
).start() ).start()
else:
# Send text-only response if audio generation isn't available
emit('text_response', {'text': bot_response}, room=session_id)
# Add text-only bot response to conversation history
bot_segment = Segment(
text=bot_response,
speaker=1, # Bot is speaker 1
audio=torch.zeros(1) # Placeholder empty audio
)
context['segments'].append(bot_segment)
except Exception as e: except Exception as e:
print(f"Error processing speech: {e}") print(f"Error processing utterance: {e}")
emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id) emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
finally: finally:
# Cleanup temp file # Clear processing flag
if os.path.exists(temp_audio_path): if session_id in user_sessions:
os.remove(temp_audio_path) session['is_processing'] = False
def transcribe_with_whisper(audio_path): def transcribe_with_whisper(audio_path):
"""Transcribe audio using Faster-Whisper""" """Transcribe audio using Faster-Whisper"""
@@ -375,49 +485,58 @@ def transcribe_with_whisper(audio_path):
# Collect all text from segments # Collect all text from segments
user_text = "" user_text = ""
for segment in segments: for segment in segments:
segment_text = segment.text.strip() user_text += segment.text.strip() + " "
print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}")
user_text += segment_text + " "
print(f"Transcribed text: {user_text.strip()}")
return user_text.strip() return user_text.strip()
def transcribe_with_google(audio_path): def transcribe_fallback(audio_path):
"""Fallback transcription using Google's speech recognition""" """Fallback transcription using Google's speech recognition"""
import speech_recognition as sr try:
recognizer = sr.Recognizer() import speech_recognition as sr
recognizer = sr.Recognizer()
with sr.AudioFile(audio_path) as source: with sr.AudioFile(audio_path) as source:
audio = recognizer.record(source) audio = recognizer.record(source)
try: try:
text = recognizer.recognize_google(audio) text = recognizer.recognize_google(audio)
return text return text
except sr.UnknownValueError: except sr.UnknownValueError:
return "" return ""
except sr.RequestError: except sr.RequestError:
# If Google API fails, try a basic energy-based VAD approach return "[Speech recognition service unavailable]"
# This is a very basic fallback and won't give good results except ImportError:
return "[Speech detected but transcription failed]" return "[Speech recognition not available]"
def generate_ai_response(user_text, session_id):
"""Generate text response using available LLM"""
if session_id not in user_sessions:
return "I'm sorry, your session has expired."
session = user_sessions[session_id]
def generate_llm_response(user_text, conversation_segments):
"""Generate text response using available model"""
if llm_model is not None and llm_tokenizer is not None: if llm_model is not None and llm_tokenizer is not None:
# Format conversation history for the LLM # Format conversation history for the LLM
conversation_history = "" prompt = "You are a helpful, friendly voice assistant. Keep your responses brief and conversational.\n\n"
for segment in conversation_segments[-5:]: # Use last 5 utterances for context
speaker_name = "User" if segment.speaker == 0 else "Assistant"
conversation_history += f"{speaker_name}: {segment.text}\n"
# Add the current user query # Add recent conversation history (last 6 turns maximum)
conversation_history += f"User: {user_text}\nAssistant:" for entry in session['conversation_history'][-6:]:
if entry['role'] == 'user':
prompt += f"User: {entry['text']}\n"
else:
prompt += f"Assistant: {entry['text']}\n"
# Add current query if not already in history
if not session['conversation_history'] or session['conversation_history'][-1]['role'] != 'user':
prompt += f"User: {user_text}\n"
prompt += "Assistant: "
try: try:
# Generate response # Generate response
inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device) inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
output = llm_model.generate( output = llm_model.generate(
inputs.input_ids, inputs.input_ids,
max_new_tokens=150, max_new_tokens=100, # Keep responses shorter for voice
temperature=0.7, temperature=0.7,
top_p=0.9, top_p=0.9,
do_sample=True do_sample=True
@@ -425,43 +544,48 @@ def generate_llm_response(user_text, conversation_segments):
response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return response.strip() return response.strip()
except Exception as e: except Exception as e:
print(f"Error generating response with LLM: {e}") print(f"Error generating LLM response: {e}")
return fallback_response(user_text) return fallback_response(user_text)
else: else:
return fallback_response(user_text) return fallback_response(user_text)
def fallback_response(user_text): def fallback_response(user_text):
"""Generate a simple fallback response when LLM is not available""" """Generate simple fallback responses when LLM is unavailable"""
# Simple rule-based responses
user_text_lower = user_text.lower() user_text_lower = user_text.lower()
if "hello" in user_text_lower or "hi" in user_text_lower: if "hello" in user_text_lower or "hi" in user_text_lower:
return "Hello! I'm a simple fallback assistant. The main language model couldn't be loaded, so I have limited capabilities." return "Hello! How can I help you today?"
elif "how are you" in user_text_lower: elif "how are you" in user_text_lower:
return "I'm functioning within my limited capabilities. How can I assist you today?" return "I'm doing well, thanks for asking! How about you?"
elif "thank" in user_text_lower: elif "thank" in user_text_lower:
return "You're welcome! Let me know if there's anything else I can help with." return "You're welcome! Happy to help."
elif "bye" in user_text_lower or "goodbye" in user_text_lower: elif "bye" in user_text_lower or "goodbye" in user_text_lower:
return "Goodbye! Have a great day!" return "Goodbye! Have a great day!"
elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]): elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]):
return "I'm running in fallback mode and can't answer complex questions. Please try again when the main language model is available." return "That's an interesting question. I wish I could provide a better answer in my current fallback mode."
else: else:
return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available." return "I see. Tell me more about that."
def generate_and_stream_audio_realtime(text, conversation_segments, session_id): def stream_ai_response(text, session_id):
"""Generate audio response using CSM and stream it in real-time to client""" """Generate and stream audio response in real-time chunks"""
if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']: if session_id not in user_sessions:
return return
session = user_sessions[session_id]
try: try:
# Use the last few conversation segments as context # Signal start of AI speech
context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments emit('ai_speech_start', room=session_id)
# Use the last few conversation segments as context (up to 4)
context_segments = session['segments'][-4:] if len(session['segments']) > 4 else session['segments']
# Generate audio for bot response # Generate audio for bot response
audio = csm_generator.generate( audio = csm_generator.generate(
@@ -473,23 +597,26 @@ def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
topk=50 topk=50
) )
# Store the full audio for conversation history # Create and store bot segment
bot_segment = Segment( bot_segment = Segment(
text=text, text=text,
speaker=1, # Bot is speaker 1 speaker=1,
audio=audio audio=audio
) )
if session_id in conversation_context and conversation_context[session_id]['active_session']:
conversation_context[session_id]['segments'].append(bot_segment) if session_id in user_sessions:
session['segments'].append(bot_segment)
# Stream audio in small chunks for more responsive playback # Stream audio in small chunks for more responsive playback
chunk_size = 4800 # 200ms at 24kHz chunk_size = AUDIO_CHUNK_SIZE # Size defined in constants
for i in range(0, len(audio), chunk_size): for i in range(0, len(audio), chunk_size):
if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']: # Check if we should stop (user interrupted)
print("Audio streaming interrupted or session ended") if session_id not in user_sessions or session['should_interrupt_ai']:
print("AI speech interrupted")
break break
# Get next chunk
chunk = audio[i:i+chunk_size] chunk = audio[i:i+chunk_size]
# Convert audio chunk to base64 for streaming # Convert audio chunk to base64 for streaming
@@ -499,32 +626,48 @@ def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8') audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')
# Send chunk to client # Send chunk to client
socketio.emit('ai_stream_data', { socketio.emit('ai_speech_chunk', {
'audio': audio_b64, 'audio': audio_b64,
'is_last': i + chunk_size >= len(audio) 'is_last': i + chunk_size >= len(audio)
}, room=session_id) }, room=session_id)
# Simulate real-time speech by adding a small delay # Small sleep for more natural pacing
# Remove this in production for faster response time.sleep(0.06) # Slight delay for smoother playback
time.sleep(0.15) # Slight delay for more natural timing
# Signal end of stream # Signal end of AI speech
if session_id in active_audio_streams and active_audio_streams[session_id]['active']: if session_id in user_sessions:
socketio.emit('ai_stream_end', {}, room=session_id) session['is_ai_speaking'] = False
active_audio_streams[session_id]['active'] = False session['is_turn_active'] = False # End conversation turn
socketio.emit('ai_speech_end', room=session_id)
except Exception as e: except Exception as e:
print(f"Error generating or streaming audio: {e}") print(f"Error streaming AI response: {e}")
# Send error message to client if session_id in user_sessions:
if session_id in conversation_context and conversation_context[session_id]['active_session']: session['is_ai_speaking'] = False
socketio.emit('error', { session['is_turn_active'] = False
'message': f'Error generating audio: {str(e)}' socketio.emit('error', {'message': f'Error generating audio: {str(e)}'}, room=session_id)
}, room=session_id) socketio.emit('ai_speech_end', room=session_id)
# Signal stream end to unblock client @socketio.on('interrupt_ai')
socketio.emit('ai_stream_end', {}, room=session_id) def handle_interrupt():
if session_id in active_audio_streams: """Handle explicit AI interruption request from client"""
active_audio_streams[session_id]['active'] = False session_id = request.sid
if session_id in user_sessions:
user_sessions[session_id]['should_interrupt_ai'] = True
emit('ai_interrupted', room=session_id)
@socketio.on('get_config')
def handle_get_config():
"""Send configuration to client"""
session_id = request.sid
if session_id in user_sessions:
emit('config', {
'client_sample_rate': CLIENT_SAMPLE_RATE,
'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000,
'whisper_available': whisper_model is not None,
'csm_available': csm_generator is not None,
'ice_servers': ICE_SERVERS
})
if __name__ == '__main__': if __name__ == '__main__':
# Ensure the existing index.html file is in the correct location # Ensure the existing index.html file is in the correct location
@@ -538,6 +681,6 @@ if __name__ == '__main__':
print("Starting model loading...") print("Starting model loading...")
load_models() load_models()
# Start the server with eventlet for better WebSocket performance # Start the server
print("Starting Flask SocketIO server...") print("Starting Flask SocketIO server...")
socketio.run(app, host='0.0.0.0', port=5000, debug=False) socketio.run(app, host='0.0.0.0', port=5000, debug=False)
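
For quick end-to-end checks without a browser, the Socket.IO event flow can be exercised from a small python-socketio client. The sketch below is only illustrative: event names are taken from the server code above, but the payload format (base64-encoded 16-bit little-endian PCM) and the use of noise as stand-in speech are assumptions; real tests should stream recorded speech.

# Minimal test client sketch (assumes: pip install "python-socketio[client]" numpy).
import base64
import time

import numpy as np
import socketio

sio = socketio.Client()

@sio.on('session_ready')
def on_ready(config):
    print("Server config:", config)

@sio.on('transcription')
def on_transcription(data):
    print("Transcribed:", data['text'])

@sio.on('ai_response_text')
def on_ai_text(data):
    print("Assistant:", data['text'])

@sio.on('ai_speech_chunk')
def on_ai_audio(data):
    # data['audio'] is a base64-encoded audio chunk; a real client would queue and play it.
    print("Audio chunk received, is_last =", data.get('is_last'))

sio.connect('http://localhost:5000')

# Half a second of silence, one second of noise as stand-in "speech", then silence again,
# sent in small chunks the way a browser capture loop would.
rate = 44100
signal = np.concatenate([
    np.zeros(rate // 2, dtype=np.int16),
    (np.random.randn(rate) * 8000).astype(np.int16),
    np.zeros(rate // 2, dtype=np.int16),
])
chunk = 1024
for i in range(0, len(signal), chunk):
    pcm = signal[i:i + chunk].tobytes()
    sio.emit('audio_stream', {'audio': base64.b64encode(pcm).decode('ascii')})
    time.sleep(chunk / rate)  # pace roughly in real time

time.sleep(5)  # give the server time to transcribe and respond
sio.disconnect()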