Demo Fixes 14

2025-03-30 08:36:50 -04:00
parent a55b3f52a4
commit 12383d5e8b
2 changed files with 475 additions and 33 deletions

View File

@@ -175,12 +175,124 @@
margin-top: 4px;
text-align: right;
}
.text-only-indicator {
font-size: 0.8em;
color: #e74c3c;
margin-top: 4px;
font-style: italic;
}
.status-message {
text-align: center;
padding: 8px;
margin: 10px 0;
background-color: #f8f9fa;
border-radius: 5px;
color: #666;
font-size: 0.9em;
}
/* Audio visualizer styles */
.visualizer-container {
width: 100%;
height: 120px;
margin: 15px 0;
border-radius: 10px;
overflow: hidden;
background-color: #000;
position: relative;
}
#visualizer {
width: 100%;
height: 100%;
display: block;
}
.visualizer-label {
position: absolute;
top: 10px;
left: 10px;
color: white;
font-size: 0.8em;
background-color: rgba(0, 0, 0, 0.5);
padding: 4px 8px;
border-radius: 4px;
}
/* Modern switch for visualizer toggle */
.switch-container {
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 10px;
}
.switch {
position: relative;
display: inline-block;
width: 50px;
height: 24px;
margin-left: 10px;
}
.switch input {
opacity: 0;
width: 0;
height: 0;
}
.slider {
position: absolute;
cursor: pointer;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: #ccc;
transition: .4s;
border-radius: 24px;
}
.slider:before {
position: absolute;
content: "";
height: 16px;
width: 16px;
left: 4px;
bottom: 4px;
background-color: white;
transition: .4s;
border-radius: 50%;
}
input:checked + .slider {
background-color: #4CAF50;
}
input:checked + .slider:before {
transform: translateX(26px);
}
</style>
</head>
<body>
<h1>Voice Assistant with CSM & Whisper</h1>
<div id="conversation"></div>
<div class="switch-container">
<span>Audio Visualizer</span>
<label class="switch">
<input type="checkbox" id="visualizerToggle" checked>
<span class="slider"></span>
</label>
</div>
<div class="visualizer-container" id="visualizerContainer">
<canvas id="visualizer"></canvas>
<div class="visualizer-label" id="visualizerLabel">Listening...</div>
</div>
<div id="controls"> <div id="controls">
<button id="recordButton">Hold to Speak</button> <button id="recordButton">Hold to Speak</button>
</div> </div>
@@ -201,27 +313,80 @@
const conversation = document.getElementById('conversation');
const status = document.getElementById('status');
const audioWave = document.getElementById('audioWave');
const visualizerToggle = document.getElementById('visualizerToggle');
const visualizerContainer = document.getElementById('visualizerContainer');
const visualizerLabel = document.getElementById('visualizerLabel');
const canvas = document.getElementById('visualizer');
const canvasCtx = canvas.getContext('2d');
let mediaRecorder;
let audioChunks = [];
let isRecording = false;
let audioSendInterval;
let sessionActive = false;
let reconnectAttempts = 0;
let audioStream = null;
let audioAnalyser = null;
let visualizerActive = true;
let visualizerAnimationId = null;
let audioBufferSource = null;
// Initialize audio context
const audioContext = new (window.AudioContext || window.webkitAudioContext)();
// Set up canvas size
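// The drawing buffer needs explicit pixel dimensions; relying on CSS width/height alone
// would just stretch the default 300x150 canvas.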
function setupCanvas() {
canvas.width = visualizerContainer.offsetWidth;
canvas.height = visualizerContainer.offsetHeight;
}
// Handle visualizer toggle
visualizerToggle.addEventListener('change', function() {
visualizerActive = this.checked;
visualizerContainer.style.display = visualizerActive ? 'block' : 'none';
if (!visualizerActive && visualizerAnimationId) {
cancelAnimationFrame(visualizerAnimationId);
visualizerAnimationId = null;
} else if (visualizerActive && audioAnalyser) {
drawVisualizer();
}
});
// Connect to server
socket.on('connect', () => {
status.textContent = 'Connected to server';
sessionActive = true;
reconnectAttempts = 0;
if (conversation.children.length > 0) {
addStatusMessage("Reconnected to server");
}
});
socket.on('disconnect', () => {
status.textContent = 'Disconnected from server';
sessionActive = false;
addStatusMessage("Disconnected from server. Attempting to reconnect...");
// Attempt to reconnect
tryReconnect();
});
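// Manual reconnect fallback: retries up to 5 times with a linearly growing delay
// (1 s, 2 s, ... 5 s) before asking the user to refresh the page.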
function tryReconnect() {
if (reconnectAttempts < 5) {
reconnectAttempts++;
setTimeout(() => {
if (!sessionActive) {
socket.connect();
}
}, 1000 * reconnectAttempts);
} else {
addStatusMessage("Failed to reconnect. Please refresh the page.");
}
}
socket.on('ready', (data) => {
status.textContent = data.message;
setupAudioRecording();
@@ -230,34 +395,87 @@
socket.on('transcription', (data) => {
addMessage('user', data.text);
status.textContent = 'Assistant is thinking...';
visualizerLabel.textContent = 'Processing...';
});
socket.on('audio_response', (data) => {
// Play audio
status.textContent = 'Playing response...';
visualizerLabel.textContent = 'Assistant speaking...';
// Create audio element
const audio = new Audio('data:audio/wav;base64,' + data.audio);
// Visualize assistant audio if visualizer is active
if (visualizerActive) {
visualizeResponseAudio(audio);
}
audio.onended = () => {
status.textContent = 'Ready to record';
visualizerLabel.textContent = 'Listening...';
if (audioBufferSource) {
audioBufferSource.disconnect();
audioBufferSource = null;
}
};
audio.onerror = () => {
status.textContent = 'Error playing audio';
visualizerLabel.textContent = 'Listening...';
console.error('Error playing audio response');
};
audio.play().catch(err => {
status.textContent = 'Error playing audio: ' + err.message;
visualizerLabel.textContent = 'Listening...';
console.error('Error playing audio:', err);
});
// Display text
addMessage('bot', data.text, false);
});
// Visualize response audio
async function visualizeResponseAudio(audioElement) {
try {
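// createMediaElementSource can only be called once per element and reroutes playback
// through the AudioContext, which is why the analyser is chained to the destination below.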
// Create media element source
const audioSource = audioContext.createMediaElementSource(audioElement);
// Create analyser
const analyser = audioContext.createAnalyser();
analyser.fftSize = 2048;
// Connect
audioSource.connect(analyser);
analyser.connect(audioContext.destination);
// Store references so the onended handler can disconnect the source when playback finishes
audioBufferSource = audioSource;
audioAnalyser = analyser;
// Start visualization
drawVisualizer();
} catch (e) {
console.error('Error setting up audio visualization:', e);
}
}
// Handle text-only responses when audio generation isn't available
socket.on('text_response', (data) => {
status.textContent = 'Received text response';
visualizerLabel.textContent = 'Text only (no audio)';
addMessage('bot', data.text, true);
setTimeout(() => {
status.textContent = 'Ready to record';
visualizerLabel.textContent = 'Listening...';
}, 1000);
});
socket.on('error', (data) => {
status.textContent = 'Error: ' + data.message;
visualizerLabel.textContent = 'Error occurred';
console.error('Server error:', data.message);
addStatusMessage("Error: " + data.message);
});
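// --- Illustrative sketch (not part of this commit): minimal client-side handlers for the
// --- streaming protocol added in the server file below. Event names and payload fields
// --- ('text', 'audio', 'is_last') follow the server code; the pendingChunks queue and the
// --- playNextChunk helper are hypothetical glue for sequential playback.
let pendingChunks = [];
let chunkPlaying = false;

socket.on('start_streaming_response', (data) => {
    addMessage('bot', data.text, false);
    status.textContent = 'Streaming response...';
    visualizerLabel.textContent = 'Assistant speaking...';
    socket.emit('request_audio_chunk'); // start pulling chunks from the server queue
});

socket.on('audio_chunk', (data) => {
    // Each chunk is a complete base64-encoded WAV, so it can be played on its own.
    pendingChunks.push('data:audio/wav;base64,' + data.audio);
    playNextChunk();
    socket.emit('request_audio_chunk'); // keep pulling until 'end_streaming' arrives
});

socket.on('wait_for_chunk', () => {
    // Generation is still running on the server; poll again shortly.
    setTimeout(() => socket.emit('request_audio_chunk'), 250);
});

socket.on('end_streaming', () => {
    status.textContent = 'Ready to record';
    visualizerLabel.textContent = 'Listening...';
});

function playNextChunk() {
    if (chunkPlaying || pendingChunks.length === 0) return;
    chunkPlaying = true;
    const audio = new Audio(pendingChunks.shift());
    audio.onended = () => { chunkPlaying = false; playNextChunk(); };
    audio.play().catch(err => {
        chunkPlaying = false;
        console.error('Error playing audio chunk:', err);
    });
}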
function setupAudioRecording() {
@@ -267,9 +485,29 @@
return;
}
// Set up canvas
setupCanvas();
// Get user media
navigator.mediaDevices.getUserMedia({ audio: true })
.then(stream => {
// Store stream for visualizer
audioStream = stream;
// Create audio analyser for visualization
const source = audioContext.createMediaStreamSource(stream);
const analyser = audioContext.createAnalyser();
analyser.fftSize = 2048;
source.connect(analyser);
// Store analyser for visualization
audioAnalyser = analyser;
// Start visualizer if enabled
if (visualizerActive) {
drawVisualizer();
}
// Setup recording with better audio quality
const options = {
mimeType: 'audio/webm',
@@ -293,12 +531,6 @@
processRecording();
};
// Create audio analyzer for visualization
const source = audioContext.createMediaStreamSource(stream);
const analyzer = audioContext.createAnalyser();
analyzer.fftSize = 2048;
source.connect(analyzer);
// Setup button handlers with better touch handling
recordButton.addEventListener('mousedown', startRecording);
recordButton.addEventListener('touchstart', (e) => {
@@ -322,6 +554,57 @@
});
}
// Draw visualizer animation
function drawVisualizer() {
if (!visualizerActive || !audioAnalyser) {
return;
}
visualizerAnimationId = requestAnimationFrame(drawVisualizer);
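// frequencyBinCount is half of fftSize, so 1024 frequency bins for fftSize = 2048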
const bufferLength = audioAnalyser.frequencyBinCount;
const dataArray = new Uint8Array(bufferLength);
// Get frequency data
audioAnalyser.getByteFrequencyData(dataArray);
// Clear canvas
canvasCtx.fillStyle = '#000';
canvasCtx.fillRect(0, 0, canvas.width, canvas.height);
// Draw visualization based on audio data
const barWidth = (canvas.width / bufferLength) * 2.5;
let x = 0;
// Choose color based on state
let gradient;
if (isRecording) {
// Red gradient for recording
gradient = canvasCtx.createLinearGradient(0, 0, 0, canvas.height);
gradient.addColorStop(0, 'rgba(255, 0, 0, 0.8)');
gradient.addColorStop(1, 'rgba(255, 80, 80, 0.2)');
} else if (visualizerLabel.textContent === 'Assistant speaking...') {
// Blue gradient for assistant
gradient = canvasCtx.createLinearGradient(0, 0, 0, canvas.height);
gradient.addColorStop(0, 'rgba(0, 120, 255, 0.8)');
gradient.addColorStop(1, 'rgba(80, 160, 255, 0.2)');
} else {
// Green gradient for listening
gradient = canvasCtx.createLinearGradient(0, 0, 0, canvas.height);
gradient.addColorStop(0, 'rgba(0, 200, 80, 0.8)');
gradient.addColorStop(1, 'rgba(80, 255, 120, 0.2)');
}
for (let i = 0; i < bufferLength; i++) {
const barHeight = (dataArray[i] / 255) * canvas.height;
canvasCtx.fillStyle = gradient;
canvasCtx.fillRect(x, canvas.height - barHeight, barWidth, barHeight);
x += barWidth + 1;
}
}
function startRecording() {
if (!isRecording && sessionActive) {
audioChunks = [];
@@ -329,6 +612,7 @@
recordButton.classList.add('recording');
recordButton.textContent = 'Release to Stop';
status.textContent = 'Recording...';
visualizerLabel.textContent = 'Recording...';
audioWave.classList.remove('hidden');
isRecording = true;
@@ -350,6 +634,7 @@
recordButton.classList.remove('recording');
recordButton.textContent = 'Hold to Speak';
status.textContent = 'Processing speech...';
visualizerLabel.textContent = 'Processing...';
audioWave.classList.add('hidden');
isRecording = false;
}
@@ -358,6 +643,7 @@
function processRecording() {
if (audioChunks.length === 0) {
status.textContent = 'No audio recorded';
visualizerLabel.textContent = 'Listening...';
return;
}
@@ -380,11 +666,13 @@
} catch (e) {
console.error('Error processing audio:', e);
status.textContent = 'Error processing audio';
visualizerLabel.textContent = 'Error';
}
};
fileReader.onerror = () => {
status.textContent = 'Error reading audio data';
visualizerLabel.textContent = 'Error';
};
fileReader.readAsArrayBuffer(audioBlob);
@@ -403,7 +691,7 @@
return float32Array;
}
function addMessage(sender, text, textOnly = false) {
const containerDiv = document.createElement('div');
containerDiv.className = sender === 'user' ? 'message-container user-message-container' : 'message-container bot-message-container';
@@ -422,12 +710,39 @@
infoDiv.className = 'transcription-info';
infoDiv.textContent = 'Transcribed with Whisper';
containerDiv.appendChild(infoDiv);
} else if (textOnly) {
// Add indicator for text-only response
const textOnlyDiv = document.createElement('div');
textOnlyDiv.className = 'text-only-indicator';
textOnlyDiv.textContent = 'Text-only response (audio unavailable)';
containerDiv.appendChild(textOnlyDiv);
}
conversation.appendChild(containerDiv);
conversation.scrollTop = conversation.scrollHeight;
}
function addStatusMessage(message) {
const statusDiv = document.createElement('div');
statusDiv.className = 'status-message';
statusDiv.textContent = message;
conversation.appendChild(statusDiv);
conversation.scrollTop = conversation.scrollHeight;
// Auto-remove status messages after 10 seconds
setTimeout(() => {
if (conversation.contains(statusDiv)) {
statusDiv.style.opacity = '0';
statusDiv.style.transition = 'opacity 0.5s';
setTimeout(() => {
if (conversation.contains(statusDiv)) {
conversation.removeChild(statusDiv);
}
}, 500);
}
}, 10000);
}
function arrayBufferToBase64(buffer) {
let binary = '';
const bytes = new Uint8Array(buffer);
@@ -450,6 +765,34 @@
if (socket && socket.connected) {
socket.disconnect();
}
if (visualizerAnimationId) {
cancelAnimationFrame(visualizerAnimationId);
}
});
// Add a reload button for debugging
const reloadButton = document.createElement('button');
reloadButton.textContent = '🔄 Reload';
reloadButton.style.position = 'fixed';
reloadButton.style.bottom = '10px';
reloadButton.style.right = '10px';
reloadButton.style.padding = '5px 10px';
reloadButton.style.fontSize = '12px';
reloadButton.style.backgroundColor = '#f5f5f5';
reloadButton.style.border = '1px solid #ddd';
reloadButton.style.borderRadius = '4px';
reloadButton.style.cursor = 'pointer';
reloadButton.addEventListener('click', () => {
window.location.reload();
});
document.body.appendChild(reloadButton);
// Handle window resize to update canvas size
window.addEventListener('resize', () => {
setupCanvas();
});
</script>
</body>

View File

@@ -12,6 +12,10 @@ from collections import deque
import requests
import huggingface_hub
from generator import load_csm_1b, Segment
import threading
import queue
from flask import stream_with_context, Response
import time
# Configure environment with longer timeouts
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600"  # 10 minutes timeout for downloads
@@ -124,6 +128,8 @@ def load_models():
# Store conversation context
conversation_context = {}  # session_id -> context
CHUNK_SIZE = 24000 # Number of audio samples per chunk (1 second at 24kHz)
audio_stream_queues = {} # session_id -> queue for audio chunks
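# Chunks are produced by a background generation thread and pulled by the client one
# 'request_audio_chunk' event at a time (see handle_request_audio_chunk below).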
@app.route('/')
def index():
@@ -144,8 +150,14 @@ def handle_connect():
@socketio.on('disconnect')
def handle_disconnect():
print(f"Client disconnected: {request.sid}")
session_id = request.sid
# Clean up resources
if session_id in conversation_context:
del conversation_context[session_id]
if session_id in audio_stream_queues:
del audio_stream_queues[session_id]
@socketio.on('start_speaking')
def handle_start_speaking():
@@ -191,7 +203,7 @@ def is_silence(audio_tensor, threshold=0.02):
return torch.mean(torch.abs(audio_tensor)) < threshold
def process_user_utterance(session_id):
"""Process completed user utterance, generate response and send audio back""" """Process completed user utterance, generate response and stream audio back"""
context = conversation_context[session_id]
if not context['audio_buffer']:
@@ -234,37 +246,32 @@ def process_user_utterance(session_id):
)
context['segments'].append(user_segment)
# Generate bot response text
bot_response = generate_llm_response(user_text, context['segments'])
print(f"Bot response: {bot_response}")
# Send transcribed text to client
emit('transcription', {'text': user_text}, room=session_id)
# Generate and stream audio response if CSM is available
if csm_generator is not None:
# Set up streaming queue for this session
if session_id not in audio_stream_queues:
audio_stream_queues[session_id] = queue.Queue()
else:
# Clear any existing items in the queue
while not audio_stream_queues[session_id].empty():
audio_stream_queues[session_id].get()
# Start audio generation in a separate thread to not block the server
threading.Thread(
target=generate_and_stream_audio,
args=(bot_response, context['segments'], session_id),
daemon=True
).start()
# Initial response with text
emit('start_streaming_response', {'text': bot_response}, room=session_id)
else:
# Send text-only response if audio generation isn't available
emit('text_response', {'text': bot_response}, room=session_id)
@@ -391,6 +398,98 @@ def generate_audio_response(text, conversation_segments):
# Return silence as fallback
return torch.zeros(csm_generator.sample_rate * 3)  # 3 seconds of silence
def generate_and_stream_audio(text, conversation_segments, session_id):
"""Generate audio response using CSM and stream it in chunks"""
try:
# Use the last few conversation segments as context
context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
# Generate full audio for bot response
audio = csm_generator.generate(
text=text,
speaker=1, # Bot is speaker 1
context=context_segments,
max_audio_length_ms=10000, # 10 seconds max
temperature=0.9,
topk=50
)
# Store the full audio for conversation history
bot_segment = Segment(
text=text,
speaker=1, # Bot is speaker 1
audio=audio
)
if session_id in conversation_context:
conversation_context[session_id]['segments'].append(bot_segment)
# Split audio into chunks for streaming
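# Each chunk is written out as a standalone WAV (header included), so the client can decode
# it independently; CHUNK_SIZE samples is roughly one second of audio at 24 kHz.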
chunk_size = CHUNK_SIZE
for i in range(0, len(audio), chunk_size):
chunk = audio[i:i+chunk_size]
# Convert audio chunk to base64 for streaming
audio_bytes = io.BytesIO()
torchaudio.save(audio_bytes, chunk.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav")
audio_bytes.seek(0)
audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')
# Send the chunk to the client
if session_id in audio_stream_queues:
audio_stream_queues[session_id].put({
'audio': audio_b64,
'is_last': i + chunk_size >= len(audio)
})
else:
# Session was disconnected before we finished generating
break
# Signal the end of streaming if queue still exists
if session_id in audio_stream_queues:
# Put a None sentinel on the queue to signal the end of streaming
audio_stream_queues[session_id].put(None)
except Exception as e:
print(f"Error generating or streaming audio: {e}")
# Send error message to client
if session_id in conversation_context:
socketio.emit('error', {
'message': f'Error generating audio: {str(e)}'
}, room=session_id)
# Send a final message to unblock the client
if session_id in audio_stream_queues:
audio_stream_queues[session_id].put(None)
@socketio.on('request_audio_chunk')
def handle_request_audio_chunk():
"""Send the next audio chunk in the queue to the client"""
session_id = request.sid
if session_id not in audio_stream_queues:
emit('error', {'message': 'No audio stream available'})
return
# Get the next chunk or wait for it to be available
try:
if not audio_stream_queues[session_id].empty():
chunk = audio_stream_queues[session_id].get(block=False)
# If chunk is None, we're done streaming
if chunk is None:
emit('end_streaming')
# Clean up the queue
if session_id in audio_stream_queues:
del audio_stream_queues[session_id]
else:
emit('audio_chunk', chunk)
else:
# If the queue is empty but we're still generating, tell client to wait
emit('wait_for_chunk')
except Exception as e:
print(f"Error sending audio chunk: {e}")
emit('error', {'message': f'Error streaming audio: {str(e)}'})
if __name__ == '__main__':
# Ensure the existing index.html file is in the correct location
if not os.path.exists('templates'):