Demo Update 22

2025-03-30 09:17:31 -04:00
parent 8695dd0297
commit e69d9c5da1
2 changed files with 496 additions and 463 deletions

View File

@@ -3,7 +3,7 @@
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Live Voice Assistant with CSM</title>
<title>Real-Time Voice Assistant</title>
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
<style>
body {
@@ -89,33 +89,39 @@
transition: all 0.3s ease;
}
#talkButton {
#micButton {
background-color: #4CAF50;
color: white;
width: 200px;
box-shadow: 0 4px 8px rgba(76, 175, 80, 0.3);
}
#talkButton:hover {
#micButton:hover {
background-color: #45a049;
transform: translateY(-2px);
}
#talkButton.recording {
background-color: #f44336;
#micButton.listening {
background-color: #4CAF50;
box-shadow: 0 0 0 rgba(76, 175, 80, 0.4);
animation: pulse 1.5s infinite;
}
#micButton.speaking {
background-color: #f44336;
box-shadow: 0 0 0 rgba(244, 67, 54, 0.4);
animation: pulse 1.5s infinite;
box-shadow: 0 4px 8px rgba(244, 67, 54, 0.3);
}
@keyframes pulse {
0% {
transform: scale(1);
box-shadow: 0 0 0 0 rgba(76, 175, 80, 0.4);
}
50% {
transform: scale(1.05);
70% {
box-shadow: 0 0 0 15px rgba(76, 175, 80, 0);
}
100% {
transform: scale(1);
box-shadow: 0 0 0 0 rgba(76, 175, 80, 0);
}
}
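/* Pulse effect: the box-shadow ring grows to 15px and fades to transparent over each 1.5s cycle, so the button signals activity without moving. */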
@@ -126,45 +132,24 @@
color: #657786;
}
.hidden {
display: none;
}
.transcription-info {
font-size: 0.8em;
color: #888;
margin-top: 4px;
text-align: right;
}
.text-only-indicator {
font-size: 0.8em;
color: #e74c3c;
margin-top: 4px;
font-style: italic;
}
.status-message {
text-align: center;
padding: 8px;
margin: 10px 0;
background-color: #f8f9fa;
border-radius: 5px;
color: #666;
font-size: 0.9em;
}
/* Audio visualizer styles */
.visualizer-container {
width: 100%;
height: 120px;
height: 100px;
margin: 15px 0;
border-radius: 10px;
overflow: hidden;
background-color: #000;
background-color: #1a1a1a;
position: relative;
}
.visualizer-container.user {
border: 2px solid #4CAF50;
}
.visualizer-container.ai {
border: 2px solid #2196F3;
}
#visualizer {
width: 100%;
height: 100%;
@@ -176,122 +161,59 @@
top: 10px;
left: 10px;
color: white;
font-size: 0.8em;
font-size: 0.9em;
background-color: rgba(0, 0, 0, 0.5);
padding: 4px 8px;
border-radius: 4px;
}
/* Real-time transcription */
.live-transcription {
position: absolute;
bottom: 10px;
left: 10px;
right: 10px;
color: white;
font-size: 0.9em;
background-color: rgba(0, 0, 0, 0.5);
padding: 8px;
border-radius: 4px;
text-align: center;
max-height: 60px;
overflow-y: auto;
font-style: italic;
}
/* Wave animation for active speaker */
.speaking-wave {
.speech-indicator {
display: inline-block;
margin-left: 5px;
width: 10px;
height: 10px;
border-radius: 50%;
margin-right: 5px;
vertical-align: middle;
}
.speaking-wave span {
display: inline-block;
width: 3px;
height: 12px;
margin: 0 1px;
background-color: currentColor;
border-radius: 1px;
animation: speakingWave 1s infinite ease-in-out;
}
.speaking-wave span:nth-child(2) {
animation-delay: 0.1s;
}
.speaking-wave span:nth-child(3) {
animation-delay: 0.2s;
}
.speaking-wave span:nth-child(4) {
animation-delay: 0.3s;
}
@keyframes speakingWave {
0%, 100% {
height: 4px;
}
50% {
height: 12px;
}
}
/* Modern switch for visualizer toggle */
.switch-container {
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 10px;
}
.switch {
position: relative;
display: inline-block;
width: 50px;
height: 24px;
margin-left: 10px;
}
.switch input {
opacity: 0;
width: 0;
height: 0;
}
.slider {
position: absolute;
cursor: pointer;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: #ccc;
transition: .4s;
border-radius: 24px;
}
.slider:before {
position: absolute;
content: "";
height: 16px;
width: 16px;
left: 4px;
bottom: 4px;
background-color: white;
transition: .4s;
border-radius: 50%;
}
input:checked + .slider {
.user-speaking {
background-color: #4CAF50;
animation: blink 1s infinite;
}
input:checked + .slider:before {
transform: translateX(26px);
.ai-speaking {
background-color: #2196F3;
animation: blink 1s infinite;
}
@keyframes blink {
0%, 100% { opacity: 1; }
50% { opacity: 0.4; }
}
.connection-status {
padding: 6px 10px;
border-radius: 4px;
font-size: 0.8em;
margin-top: 10px;
text-align: center;
}
.connection-status.connected {
background-color: #d4edda;
color: #155724;
}
.connection-status.connecting {
background-color: #fff3cd;
color: #856404;
}
.connection-status.disconnected {
background-color: #f8d7da;
color: #721c24;
}
/* Toast notification for feedback */
.toast {
position: fixed;
bottom: 20px;
@@ -328,38 +250,27 @@
</style>
</head>
<body>
<h1>Live Voice Assistant with CSM</h1>
<h1>Real-Time Voice Assistant</h1>
<div id="conversation"></div>
<div class="switch-container">
<span>Audio Visualizer</span>
<label class="switch">
<input type="checkbox" id="visualizerToggle" checked>
<span class="slider"></span>
</label>
</div>
<div class="visualizer-container" id="visualizerContainer">
<div class="visualizer-container user" id="visualizerContainer">
<canvas id="visualizer"></canvas>
<div class="visualizer-label" id="visualizerLabel">Listening...</div>
<div class="live-transcription" id="liveTranscription"></div>
</div>
<div id="controls">
<button id="talkButton">Press to Talk</button>
<button id="micButton">Press to Talk</button>
</div>
<div id="status">Connecting to server...</div>
<script>
const socket = io();
const talkButton = document.getElementById('talkButton');
const micButton = document.getElementById('micButton');
const conversation = document.getElementById('conversation');
const status = document.getElementById('status');
const visualizerToggle = document.getElementById('visualizerToggle');
const visualizerContainer = document.getElementById('visualizerContainer');
const visualizerLabel = document.getElementById('visualizerLabel');
const liveTranscription = document.getElementById('liveTranscription');
const canvas = document.getElementById('visualizer');
const canvasCtx = canvas.getContext('2d');
@@ -386,19 +297,6 @@
canvas.height = visualizerContainer.offsetHeight;
}
// Handle visualizer toggle
visualizerToggle.addEventListener('change', function() {
visualizerActive = this.checked;
visualizerContainer.style.display = visualizerActive ? 'block' : 'none';
if (!visualizerActive && visualizerAnimationId) {
cancelAnimationFrame(visualizerAnimationId);
visualizerAnimationId = null;
} else if (visualizerActive && audioAnalyser) {
drawVisualizer();
}
});
// Connect to server
socket.on('connect', () => {
status.textContent = 'Connected to server';
@@ -491,8 +389,8 @@
setupScriptProcessor(stream);
}
// Setup talk button
talkButton.addEventListener('click', toggleTalking);
// Setup mic button
micButton.addEventListener('click', toggleTalking);
// Setup keyboard shortcuts
document.addEventListener('keydown', (e) => {
@@ -618,8 +516,8 @@
if (!sessionActive || isAITalking) return;
isStreaming = true;
talkButton.classList.add('recording');
talkButton.textContent = 'Release to Stop';
micButton.classList.add('listening');
micButton.textContent = 'Release to Stop';
status.textContent = 'Listening...';
visualizerLabel.textContent = 'You are speaking...';
@@ -630,10 +528,6 @@
// Tell server we're starting to speak
socket.emit('start_speaking');
// Clear previous transcriptions
liveTranscription.textContent = '';
liveTranscription.classList.remove('hidden');
}
// Stop talking to the assistant
@@ -641,15 +535,12 @@
if (!isStreaming) return;
isStreaming = false;
talkButton.classList.remove('recording');
talkButton.textContent = 'Press to Talk';
micButton.classList.remove('listening');
micButton.textContent = 'Press to Talk';
status.textContent = 'Processing...';
// Tell server we're done speaking
socket.emit('stop_speaking');
// Hide live transcription temporarily
liveTranscription.classList.add('hidden');
}
// Send audio chunk to server
@@ -702,8 +593,7 @@
// Handle real-time transcription
socket.on('live_transcription', (data) => {
liveTranscription.textContent = data.text || '...';
liveTranscription.classList.remove('hidden');
visualizerLabel.textContent = data.text || '...';
});
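// Partial transcripts are mirrored into the visualizer label so the speaker can see the recognized text while still talking.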
// Handle final transcription
@@ -749,8 +639,8 @@
speakingWave.remove();
}
// Re-enable talk button if it was disabled
talkButton.disabled = false;
// Re-enable mic button if it was disabled
micButton.disabled = false;
});
// Legacy handler for text-only responses

View File

@@ -8,14 +8,14 @@ import numpy as np
from flask import Flask, render_template, request
from flask_socketio import SocketIO, emit
from transformers import AutoModelForCausalLM, AutoTokenizer
from collections import deque
import threading
import queue
import requests
import huggingface_hub
from generator import load_csm_1b, Segment
import threading
import queue
import asyncio
from collections import deque
import json
import webrtcvad # For voice activity detection
# Configure environment with longer timeouts
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600" # 10 minutes timeout for downloads
@@ -28,7 +28,7 @@ app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key'
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
# Explicitly check for CUDA and print more detailed info
# Explicitly check for CUDA and print detailed info
print("\n=== CUDA Information ===")
if torch.cuda.is_available():
print(f"CUDA is available")
@@ -47,18 +47,9 @@ try:
except:
print("cuDNN is not available (libcudnn_ops_infer.so.8 not found)")
# Check for other compute platforms
if torch.backends.mps.is_available():
print("MPS (Apple Silicon) is available")
else:
print("MPS is not available")
print("========================\n")
# Check for CUDA availability and handle potential CUDA/cuDNN issues
# Determine compute device
try:
if torch.cuda.is_available():
# Try to initialize CUDA to check if libraries are properly loaded
_ = torch.zeros(1).cuda()
device = "cuda"
whisper_compute_type = "float16"
print("🟢 CUDA is available and initialized successfully")
@@ -75,7 +66,7 @@ except Exception as e:
print("🔴 Falling back to CPU")
device = "cpu"
whisper_compute_type = "int8"
print(f"Using device: {device}")
# Initialize models with proper error handling
@@ -83,14 +74,42 @@ whisper_model = None
csm_generator = None
llm_model = None
llm_tokenizer = None
vad = None
# Constants
SAMPLE_RATE = 16000 # For VAD
VAD_FRAME_SIZE = 480 # 30ms at 16kHz for VAD
VAD_MODE = 3 # Most aggressive of webrtcvad's modes (0-3); filters out the most non-speech
AUDIO_CHUNK_SIZE = 2400 # 100ms chunks when streaming AI voice
# Audio sample rates
CLIENT_SAMPLE_RATE = 44100 # Browser WebAudio default
WHISPER_SAMPLE_RATE = 16000 # Whisper expects 16kHz
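# Frame/chunk arithmetic: 480 samples / 16000 Hz = 30 ms, one of the frame lengths
# webrtcvad accepts (10, 20, or 30 ms at 8/16/32/48 kHz); 2400 samples / 24000 Hz = 100 ms
# of generated speech per streamed chunk.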
# Session data structures
user_sessions = {} # session_id -> complete session data
# WebRTC ICE servers (STUN/TURN servers for NAT traversal)
ICE_SERVERS = [
{"urls": "stun:stun.l.google.com:19302"},
{"urls": "stun:stun1.l.google.com:19302"}
]
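# These are Google's public STUN servers; they only let peers discover their public
# addresses. Clients behind stricter NATs would typically also need a TURN relay,
# which this demo does not configure.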
def load_models():
global whisper_model, csm_generator, llm_model, llm_tokenizer
"""Load all necessary models"""
global whisper_model, csm_generator, llm_model, llm_tokenizer, vad
# Initialize Voice Activity Detector
try:
vad = webrtcvad.Vad(VAD_MODE)
print("Voice Activity Detector initialized")
except Exception as e:
print(f"Error initializing VAD: {e}")
vad = None
# Initialize Faster-Whisper for transcription
try:
print("Loading Whisper model...")
# Import here to avoid immediate import errors if package is missing
from faster_whisper import WhisperModel
whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper")
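# "base" keeps latency low; larger Whisper checkpoints ("small", "medium") transcribe
# more accurately at the cost of speed. compute_type follows the device selected above:
# float16 on GPU, int8 quantization on CPU.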
print("Whisper model loaded successfully")
@@ -110,9 +129,8 @@ def load_models():
# Initialize Llama 3.2 model for response generation
try:
print("Loading Llama 3.2 model...")
llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources
llm_model_id = "meta-llama/Llama-3.2-1B"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama")
# Use the right data type based on device
dtype = torch.bfloat16 if device != "cpu" else torch.float32
llm_model = AutoModelForCausalLM.from_pretrained(
llm_model_id,
@@ -126,247 +144,339 @@ def load_models():
print(f"Error loading Llama 3.2 model: {e}")
print("Will use a fallback response generation method")
# Store conversation context
conversation_context = {} # session_id -> context
active_audio_streams = {} # session_id -> stream status
@app.route('/')
def index():
"""Serve the main interface"""
return render_template('index.html')
@socketio.on('connect')
def handle_connect():
print(f"Client connected: {request.sid}")
conversation_context[request.sid] = {
"""Handle new client connection"""
session_id = request.sid
print(f"Client connected: {session_id}")
# Initialize session data
user_sessions[session_id] = {
# Conversation context
'segments': [],
'speakers': [0, 1], # 0 = user, 1 = bot
'audio_buffer': deque(maxlen=10), # Store recent audio chunks
'is_speaking': False,
'last_activity': time.time(),
'active_session': True,
'transcription_buffer': [] # For real-time transcription
'conversation_history': [],
'is_turn_active': False,
# Audio buffers and state
'vad_buffer': deque(maxlen=30), # ~30 most recent audio chunks (roughly 1s of history)
'audio_buffer': bytearray(),
'is_user_speaking': False,
'last_vad_active': time.time(),
'silence_duration': 0,
'speech_frames': 0,
# AI state
'is_ai_speaking': False,
'should_interrupt_ai': False,
'ai_stream_queue': queue.Queue(),
# WebRTC status
'webrtc_connected': False,
'webrtc_peer_id': None,
# Processing flags
'is_processing': False,
'pending_user_audio': None
}
emit('ready', {
'message': 'Connection established',
'sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000
# Send config to client
emit('session_ready', {
'whisper_available': whisper_model is not None,
'csm_available': csm_generator is not None,
'llm_available': llm_model is not None,
'client_sample_rate': CLIENT_SAMPLE_RATE,
'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000,
'ice_servers': ICE_SERVERS
})
@socketio.on('disconnect')
def handle_disconnect():
print(f"Client disconnected: {request.sid}")
"""Handle client disconnection"""
session_id = request.sid
print(f"Client disconnected: {session_id}")
# Clean up resources
if session_id in conversation_context:
conversation_context[session_id]['active_session'] = False
del conversation_context[session_id]
if session_id in user_sessions:
# Signal any running threads to stop
user_sessions[session_id]['should_interrupt_ai'] = True
# Clean up resources
del user_sessions[session_id]
@socketio.on('webrtc_signal')
def handle_webrtc_signal(data):
"""Handle WebRTC signaling for P2P connection establishment"""
session_id = request.sid
if session_id not in user_sessions:
return
if session_id in active_audio_streams:
active_audio_streams[session_id]['active'] = False
del active_audio_streams[session_id]
# Simply relay the signal to the client
# In a multi-user app, we would route this to the correct peer
emit('webrtc_signal', data)
@socketio.on('webrtc_connected')
def handle_webrtc_connected(data):
"""Client notifies that WebRTC connection is established"""
session_id = request.sid
if session_id not in user_sessions:
return
user_sessions[session_id]['webrtc_connected'] = True
print(f"WebRTC connected for session {session_id}")
emit('ready_for_speech', {'message': 'Ready to start conversation'})
@socketio.on('audio_stream')
def handle_audio_stream(data):
"""Handle incoming audio stream from client"""
"""Process incoming audio stream packets from client"""
session_id = request.sid
if session_id not in conversation_context:
if session_id not in user_sessions:
return
context = conversation_context[session_id]
context['last_activity'] = time.time()
# Process different stream events
if data.get('event') == 'start':
# Client is starting to send audio
context['is_speaking'] = True
context['audio_buffer'].clear()
context['transcription_buffer'] = []
print(f"User {session_id} started streaming audio")
# If AI was speaking, interrupt it
if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
active_audio_streams[session_id]['active'] = False
emit('ai_stream_interrupt', {}, room=session_id)
elif data.get('event') == 'data':
# Audio data received
if not context['is_speaking']:
return
# Decode audio chunk
try:
audio_data = base64.b64decode(data.get('audio', ''))
if not audio_data:
return
audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
# Apply a simple noise gate
if np.mean(np.abs(audio_numpy)) < 0.01: # Very quiet
return
audio_tensor = torch.tensor(audio_numpy)
# Add to audio buffer
context['audio_buffer'].append(audio_tensor)
# Real-time transcription (periodic)
if len(context['audio_buffer']) % 3 == 0: # Process every 3 chunks
threading.Thread(
target=process_realtime_transcription,
args=(session_id,),
daemon=True
).start()
except Exception as e:
print(f"Error processing audio chunk: {e}")
elif data.get('event') == 'end':
# Client has finished sending audio
context['is_speaking'] = False
if len(context['audio_buffer']) > 0:
# Process the complete utterance
threading.Thread(
target=process_complete_utterance,
args=(session_id,),
daemon=True
).start()
print(f"User {session_id} stopped streaming audio")
def process_realtime_transcription(session_id):
"""Process incoming audio for real-time transcription"""
if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
return
context = conversation_context[session_id]
if not context['audio_buffer'] or not context['is_speaking']:
return
session = user_sessions[session_id]
try:
# Combine current buffer for transcription
buffer_copy = list(context['audio_buffer'])
if not buffer_copy:
# Decode audio data
audio_bytes = base64.b64decode(data.get('audio', ''))
if not audio_bytes or len(audio_bytes) < 2: # Need at least one sample
return
full_audio = torch.cat(buffer_copy, dim=0)
# Save audio to temporary WAV file for transcription
temp_audio_path = f"temp_rt_{session_id}.wav"
# Add to current audio buffer
session['audio_buffer'] += audio_bytes
# Check for speech using VAD
has_speech = detect_speech(audio_bytes, session_id)
# Handle speech state machine
if has_speech:
# Reset silence tracking when speech is detected
session['last_vad_active'] = time.time()
session['silence_duration'] = 0
session['speech_frames'] += 1
# If not already marked as speaking and we have enough speech frames
if not session['is_user_speaking'] and session['speech_frames'] >= 5:
on_speech_started(session_id)
else:
# No speech detected in this frame
if session['is_user_speaking']:
# Calculate silence duration
now = time.time()
session['silence_duration'] = now - session['last_vad_active']
# If silent for more than 0.8 seconds (and enough speech was heard), end the speech segment
if session['silence_duration'] > 0.8 and session['speech_frames'] > 8:
on_speech_ended(session_id)
else:
# Not speaking and no speech, just a silent frame
session['speech_frames'] = max(0, session['speech_frames'] - 1)
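# Net effect: a turn starts only after ~5 voiced frames (debouncing clicks and noise)
# and ends only after 0.8 s of continuous silence following more than 8 voiced frames,
# so short pauses mid-sentence do not cut the user off.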
except Exception as e:
print(f"Error processing audio stream: {e}")
def detect_speech(audio_bytes, session_id):
"""Use VAD to check if audio contains speech"""
if session_id not in user_sessions:
return False
session = user_sessions[session_id]
# Store in VAD buffer for history
session['vad_buffer'].append(audio_bytes)
if vad is None:
# Fallback to simple energy detection
audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
energy = np.mean(np.abs(audio_data)) / 32768.0
return energy > 0.015 # Simple threshold
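# int16 samples span [-32768, 32767], so dividing the mean absolute value by 32768
# normalizes energy to roughly [0, 1]; the 0.015 cutoff is an ad-hoc value that may
# need tuning for a given microphone.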
try:
# Ensure we have the right amount of data for VAD
audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
# If we have too much data, use just the right amount
if len(audio_data) >= VAD_FRAME_SIZE:
frame = audio_data[:VAD_FRAME_SIZE].tobytes()
return vad.is_speech(frame, SAMPLE_RATE)
# If too little data, accumulate in the VAD buffer and check periodically
if len(session['vad_buffer']) >= 3:
# Combine recent chunks to get enough data
combined = bytearray()
for chunk in list(session['vad_buffer'])[-3:]:
combined.extend(chunk)
# Extract one full VAD frame (VAD_FRAME_SIZE int16 samples = VAD_FRAME_SIZE * 2 bytes)
if len(combined) >= VAD_FRAME_SIZE * 2:
frame = combined[:VAD_FRAME_SIZE * 2]
return vad.is_speech(bytes(frame), SAMPLE_RATE)
return False
except Exception as e:
print(f"VAD error: {e}")
return False
def on_speech_started(session_id):
"""Handle start of user speech"""
if session_id not in user_sessions:
return
session = user_sessions[session_id]
# Reset audio buffer
session['audio_buffer'] = bytearray()
session['is_user_speaking'] = True
session['is_turn_active'] = True
# If AI is speaking, we need to interrupt it
if session['is_ai_speaking']:
session['should_interrupt_ai'] = True
emit('ai_interrupted_by_user', room=session_id)
# Notify client that we detected speech
emit('user_speech_start', room=session_id)
def on_speech_ended(session_id):
"""Handle end of user speech segment"""
if session_id not in user_sessions:
return
session = user_sessions[session_id]
# Mark as not speaking anymore
session['is_user_speaking'] = False
session['speech_frames'] = 0
# If no audio or already processing, skip
if len(session['audio_buffer']) < 4000 or session['is_processing']: # Skip very short captures (4000 bytes ≈ 2000 int16 samples) or when already busy
session['audio_buffer'] = bytearray()
return
# Mark as processing to prevent multiple processes
session['is_processing'] = True
# Create a copy of the audio buffer
audio_copy = session['audio_buffer']
session['audio_buffer'] = bytearray()
# Convert audio to the format needed for processing
try:
# Convert to float32 between -1 and 1
audio_np = np.frombuffer(audio_copy, dtype=np.int16).astype(np.float32) / 32768.0
audio_tensor = torch.from_numpy(audio_np)
# Resample to Whisper's expected sample rate if necessary
if CLIENT_SAMPLE_RATE != WHISPER_SAMPLE_RATE:
audio_tensor = torchaudio.functional.resample(
audio_tensor,
orig_freq=CLIENT_SAMPLE_RATE,
new_freq=WHISPER_SAMPLE_RATE
)
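# Whisper models expect 16 kHz input, so audio captured at the browser's 44.1 kHz
# default is downsampled before transcription.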
# Save as WAV for transcription
temp_audio_path = f"temp_audio_{session_id}.wav"
torchaudio.save(
temp_audio_path,
full_audio.unsqueeze(0),
44100 # Assuming 44.1kHz from client
audio_tensor.unsqueeze(0),
WHISPER_SAMPLE_RATE
)
# Transcribe with Whisper if available
if whisper_model is not None:
segments, _ = whisper_model.transcribe(temp_audio_path, beam_size=5)
text = " ".join([segment.text for segment in segments])
if text.strip():
context['transcription_buffer'].append(text)
# Send partial transcription to client
emit('partial_transcription', {'text': text}, room=session_id)
# Start transcription and response process in a thread
threading.Thread(
target=process_user_utterance,
args=(session_id, temp_audio_path, audio_tensor),
daemon=True
).start()
# Notify client that processing has started
emit('processing_speech', room=session_id)
except Exception as e:
print(f"Error in realtime transcription: {e}")
finally:
# Clean up
if os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
print(f"Error preparing audio: {e}")
session['is_processing'] = False
emit('error', {'message': f'Error processing audio: {str(e)}'}, room=session_id)
def process_complete_utterance(session_id):
"""Process completed user utterance, generate response and stream audio back"""
if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
def process_user_utterance(session_id, audio_path, audio_tensor):
"""Process user utterance, transcribe and generate response"""
if session_id not in user_sessions:
return
context = conversation_context[session_id]
if not context['audio_buffer']:
return
# Combine audio chunks
full_audio = torch.cat(list(context['audio_buffer']), dim=0)
context['audio_buffer'].clear()
# Save audio to temporary WAV file for transcription
temp_audio_path = f"temp_audio_{session_id}.wav"
torchaudio.save(
temp_audio_path,
full_audio.unsqueeze(0),
44100 # Assuming 44.1kHz from client
)
session = user_sessions[session_id]
try:
# Try using Whisper first if available
# Transcribe audio
if whisper_model is not None:
user_text = transcribe_with_whisper(temp_audio_path)
user_text = transcribe_with_whisper(audio_path)
else:
# Fallback to Google's speech recognition
user_text = transcribe_with_google(temp_audio_path)
# Fallback to another transcription service
user_text = transcribe_fallback(audio_path)
if not user_text:
print("No speech detected.")
emit('error', {'message': 'No speech detected. Please try again.'}, room=session_id)
# Clean up temp file
if os.path.exists(audio_path):
os.remove(audio_path)
# Check if we got meaningful text
if not user_text or len(user_text.strip()) < 2:
emit('no_speech_detected', room=session_id)
session['is_processing'] = False
return
print(f"Transcribed: {user_text}")
# Add to conversation segments
# Create user segment
user_segment = Segment(
text=user_text,
speaker=0, # User is speaker 0
audio=full_audio
audio=audio_tensor
)
context['segments'].append(user_segment)
session['segments'].append(user_segment)
# Generate bot response text
bot_response = generate_llm_response(user_text, context['segments'])
print(f"Bot response: {bot_response}")
# Update conversation history
session['conversation_history'].append({
'role': 'user',
'text': user_text
})
# Send transcribed text to client
# Send transcription to client
emit('transcription', {'text': user_text}, room=session_id)
# Generate and stream audio response if CSM is available
# Generate AI response
ai_response = generate_ai_response(user_text, session_id)
# Send text response to client
emit('ai_response_text', {'text': ai_response}, room=session_id)
# Update conversation history
session['conversation_history'].append({
'role': 'assistant',
'text': ai_response
})
# Generate voice response if CSM is available
if csm_generator is not None:
# Create stream state object
active_audio_streams[session_id] = {
'active': True,
'text': bot_response
}
session['is_ai_speaking'] = True
session['should_interrupt_ai'] = False
# Send initial response to prepare client
emit('ai_stream_start', {
'text': bot_response
}, room=session_id)
# Start audio generation in a separate thread
# Begin streaming audio response
threading.Thread(
target=generate_and_stream_audio_realtime,
args=(bot_response, context['segments'], session_id),
target=stream_ai_response,
args=(ai_response, session_id),
daemon=True
).start()
else:
# Send text-only response if audio generation isn't available
emit('text_response', {'text': bot_response}, room=session_id)
# Add text-only bot response to conversation history
bot_segment = Segment(
text=bot_response,
speaker=1, # Bot is speaker 1
audio=torch.zeros(1) # Placeholder empty audio
)
context['segments'].append(bot_segment)
except Exception as e:
print(f"Error processing speech: {e}")
emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id)
print(f"Error processing utterance: {e}")
emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
finally:
# Cleanup temp file
if os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
# Clear processing flag
if session_id in user_sessions:
session['is_processing'] = False
def transcribe_with_whisper(audio_path):
"""Transcribe audio using Faster-Whisper"""
@@ -375,49 +485,58 @@ def transcribe_with_whisper(audio_path):
# Collect all text from segments
user_text = ""
for segment in segments:
segment_text = segment.text.strip()
print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}")
user_text += segment_text + " "
print(f"Transcribed text: {user_text.strip()}")
user_text += segment.text.strip() + " "
return user_text.strip()
def transcribe_with_google(audio_path):
def transcribe_fallback(audio_path):
"""Fallback transcription using Google's speech recognition"""
import speech_recognition as sr
recognizer = sr.Recognizer()
with sr.AudioFile(audio_path) as source:
audio = recognizer.record(source)
try:
text = recognizer.recognize_google(audio)
return text
except sr.UnknownValueError:
return ""
except sr.RequestError:
# If Google API fails, try a basic energy-based VAD approach
# This is a very basic fallback and won't give good results
return "[Speech detected but transcription failed]"
try:
import speech_recognition as sr
recognizer = sr.Recognizer()
with sr.AudioFile(audio_path) as source:
audio = recognizer.record(source)
try:
text = recognizer.recognize_google(audio)
return text
except sr.UnknownValueError:
return ""
except sr.RequestError:
return "[Speech recognition service unavailable]"
except ImportError:
return "[Speech recognition not available]"
def generate_llm_response(user_text, conversation_segments):
"""Generate text response using available model"""
def generate_ai_response(user_text, session_id):
"""Generate text response using available LLM"""
if session_id not in user_sessions:
return "I'm sorry, your session has expired."
session = user_sessions[session_id]
if llm_model is not None and llm_tokenizer is not None:
# Format conversation history for the LLM
conversation_history = ""
for segment in conversation_segments[-5:]: # Use last 5 utterances for context
speaker_name = "User" if segment.speaker == 0 else "Assistant"
conversation_history += f"{speaker_name}: {segment.text}\n"
prompt = "You are a helpful, friendly voice assistant. Keep your responses brief and conversational.\n\n"
# Add the current user query
conversation_history += f"User: {user_text}\nAssistant:"
# Add recent conversation history (last 6 entries at most)
for entry in session['conversation_history'][-6:]:
if entry['role'] == 'user':
prompt += f"User: {entry['text']}\n"
else:
prompt += f"Assistant: {entry['text']}\n"
# Add current query if not already in history
if not session['conversation_history'] or session['conversation_history'][-1]['role'] != 'user':
prompt += f"User: {user_text}\n"
prompt += "Assistant: "
try:
# Generate response
inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device)
inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
output = llm_model.generate(
inputs.input_ids,
max_new_tokens=150,
max_new_tokens=100, # Keep responses shorter for voice
temperature=0.7,
top_p=0.9,
do_sample=True
@@ -425,43 +544,48 @@ def generate_llm_response(user_text, conversation_segments):
response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
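# Decode only the tokens generated after the prompt (skip the first
# inputs.input_ids.shape[1] tokens) so the prompt text is not echoed back.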
return response.strip()
except Exception as e:
print(f"Error generating response with LLM: {e}")
print(f"Error generating LLM response: {e}")
return fallback_response(user_text)
else:
return fallback_response(user_text)
def fallback_response(user_text):
"""Generate a simple fallback response when LLM is not available"""
# Simple rule-based responses
"""Generate simple fallback responses when LLM is unavailable"""
user_text_lower = user_text.lower()
if "hello" in user_text_lower or "hi" in user_text_lower:
return "Hello! I'm a simple fallback assistant. The main language model couldn't be loaded, so I have limited capabilities."
return "Hello! How can I help you today?"
elif "how are you" in user_text_lower:
return "I'm functioning within my limited capabilities. How can I assist you today?"
return "I'm doing well, thanks for asking! How about you?"
elif "thank" in user_text_lower:
return "You're welcome! Let me know if there's anything else I can help with."
return "You're welcome! Happy to help."
elif "bye" in user_text_lower or "goodbye" in user_text_lower:
return "Goodbye! Have a great day!"
elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]):
return "I'm running in fallback mode and can't answer complex questions. Please try again when the main language model is available."
return "That's an interesting question. I wish I could provide a better answer in my current fallback mode."
else:
return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available."
return "I see. Tell me more about that."
def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
"""Generate audio response using CSM and stream it in real-time to client"""
if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
def stream_ai_response(text, session_id):
"""Generate and stream audio response in real-time chunks"""
if session_id not in user_sessions:
return
session = user_sessions[session_id]
try:
# Use the last few conversation segments as context
context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
# Signal start of AI speech
emit('ai_speech_start', room=session_id)
# Use the last few conversation segments as context (up to 4)
context_segments = session['segments'][-4:] if len(session['segments']) > 4 else session['segments']
# Generate audio for bot response
audio = csm_generator.generate(
@@ -473,23 +597,26 @@ def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
topk=50
)
# Store the full audio for conversation history
# Create and store bot segment
bot_segment = Segment(
text=text,
speaker=1, # Bot is speaker 1
speaker=1,
audio=audio
)
if session_id in conversation_context and conversation_context[session_id]['active_session']:
conversation_context[session_id]['segments'].append(bot_segment)
if session_id in user_sessions:
session['segments'].append(bot_segment)
# Stream audio in small chunks for more responsive playback
chunk_size = 4800 # 200ms at 24kHz
chunk_size = AUDIO_CHUNK_SIZE # Size defined in constants
for i in range(0, len(audio), chunk_size):
if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
print("Audio streaming interrupted or session ended")
# Check if we should stop (user interrupted)
if session_id not in user_sessions or session['should_interrupt_ai']:
print("AI speech interrupted")
break
# Get next chunk
chunk = audio[i:i+chunk_size]
# Convert audio chunk to base64 for streaming
@@ -499,32 +626,48 @@ def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')
# Send chunk to client
socketio.emit('ai_stream_data', {
socketio.emit('ai_speech_chunk', {
'audio': audio_b64,
'is_last': i + chunk_size >= len(audio)
}, room=session_id)
# Simulate real-time speech by adding a small delay
# Remove this in production for faster response
time.sleep(0.15) # Slight delay for more natural timing
# Small sleep for more natural pacing
time.sleep(0.06) # Slight delay for smoother playback
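# Each chunk carries ~100 ms of audio but is emitted roughly every 60 ms, so delivery
# stays slightly ahead of real-time playback and the client can buffer without gaps.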
# Signal end of stream
if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
socketio.emit('ai_stream_end', {}, room=session_id)
active_audio_streams[session_id]['active'] = False
# Signal end of AI speech
if session_id in user_sessions:
session['is_ai_speaking'] = False
session['is_turn_active'] = False # End conversation turn
socketio.emit('ai_speech_end', room=session_id)
except Exception as e:
print(f"Error generating or streaming audio: {e}")
# Send error message to client
if session_id in conversation_context and conversation_context[session_id]['active_session']:
socketio.emit('error', {
'message': f'Error generating audio: {str(e)}'
}, room=session_id)
# Signal stream end to unblock client
socketio.emit('ai_stream_end', {}, room=session_id)
if session_id in active_audio_streams:
active_audio_streams[session_id]['active'] = False
print(f"Error streaming AI response: {e}")
if session_id in user_sessions:
session['is_ai_speaking'] = False
session['is_turn_active'] = False
socketio.emit('error', {'message': f'Error generating audio: {str(e)}'}, room=session_id)
socketio.emit('ai_speech_end', room=session_id)
@socketio.on('interrupt_ai')
def handle_interrupt():
"""Handle explicit AI interruption request from client"""
session_id = request.sid
if session_id in user_sessions:
user_sessions[session_id]['should_interrupt_ai'] = True
emit('ai_interrupted', room=session_id)
@socketio.on('get_config')
def handle_get_config():
"""Send configuration to client"""
session_id = request.sid
if session_id in user_sessions:
emit('config', {
'client_sample_rate': CLIENT_SAMPLE_RATE,
'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000,
'whisper_available': whisper_model is not None,
'csm_available': csm_generator is not None,
'ice_servers': ICE_SERVERS
})
if __name__ == '__main__':
# Ensure the existing index.html file is in the correct location
@@ -538,6 +681,6 @@ if __name__ == '__main__':
print("Starting model loading...")
load_models()
# Start the server with eventlet for better WebSocket performance
# Start the server
print("Starting Flask SocketIO server...")
socketio.run(app, host='0.0.0.0', port=5000, debug=False)