Backend Server Update

2025-03-29 22:06:00 -04:00
parent e8a9207da4
commit 06fa7936a3
3 changed files with 360 additions and 284 deletions

.gitignore

@@ -134,3 +134,4 @@ dist
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*
Backend/test.py


@@ -10,60 +10,113 @@
max-width: 800px;
margin: 0 auto;
padding: 20px;
background-color: #f9f9f9;
}
.conversation {
border: 1px solid #ccc;
border-radius: 8px;
padding: 15px;
height: 300px;
border: 1px solid #ddd;
border-radius: 12px;
padding: 20px;
height: 400px;
overflow-y: auto;
margin-bottom: 15px;
margin-bottom: 20px;
background-color: white;
box-shadow: 0 2px 10px rgba(0,0,0,0.05);
}
.message {
margin-bottom: 10px;
padding: 8px;
border-radius: 8px;
margin-bottom: 15px;
padding: 12px;
border-radius: 12px;
max-width: 80%;
line-height: 1.4;
}
.user {
background-color: #e3f2fd;
text-align: right;
margin-left: auto;
border-bottom-right-radius: 4px;
}
.ai {
background-color: #f1f1f1;
margin-right: auto;
border-bottom-left-radius: 4px;
}
.system {
background-color: #f8f9fa;
font-style: italic;
text-align: center;
font-size: 0.9em;
color: #666;
padding: 8px;
margin: 10px auto;
max-width: 90%;
}
.controls {
display: flex;
flex-direction: column;
gap: 10px;
}
.input-row {
display: flex;
gap: 10px;
}
input[type="text"] {
flex-grow: 1;
padding: 8px;
border-radius: 4px;
border: 1px solid #ccc;
gap: 15px;
justify-content: center;
align-items: center;
}
button {
padding: 8px 16px;
border-radius: 4px;
padding: 12px 24px;
border-radius: 24px;
border: none;
background-color: #4CAF50;
color: white;
cursor: pointer;
font-weight: bold;
transition: all 0.2s ease;
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}
button:hover {
background-color: #45a049;
box-shadow: 0 4px 8px rgba(0,0,0,0.15);
}
.recording {
background-color: #f44336;
animation: pulse 1.5s infinite;
}
.processing {
background-color: #FFA500;
}
select {
padding: 8px;
border-radius: 4px;
border: 1px solid #ccc;
padding: 10px;
border-radius: 24px;
border: 1px solid #ddd;
background-color: white;
}
.transcript {
font-style: italic;
color: #666;
margin-top: 5px;
}
@keyframes pulse {
0% { opacity: 1; }
50% { opacity: 0.7; }
100% { opacity: 1; }
}
.status-indicator {
display: flex;
align-items: center;
justify-content: center;
margin-top: 10px;
gap: 5px;
}
.status-dot {
width: 10px;
height: 10px;
border-radius: 50%;
background-color: #ccc;
}
.status-dot.active {
background-color: #4CAF50;
}
.status-text {
font-size: 0.9em;
color: #666;
}
audio {
width: 100%;
margin-top: 5px;
}
</style>
</head>
@@ -72,30 +125,25 @@
<div class="conversation" id="conversation"></div>
<div class="controls">
<div class="input-row">
<input type="text" id="textInput" placeholder="Type your message...">
<select id="speakerSelect">
<option value="0">Speaker 0</option>
<option value="1">Speaker 1</option>
</select>
<button id="sendText">Send</button>
</div>
<div class="input-row">
<button id="recordAudio">Record Audio</button>
<button id="clearContext">Clear Context</button>
</div>
<select id="speakerSelect">
<option value="0">Speaker 0</option>
<option value="1">Speaker 1</option>
</select>
<button id="streamButton">Start Conversation</button>
<button id="clearButton">Clear Chat</button>
</div>
<div class="status-indicator">
<div class="status-dot" id="statusDot"></div>
<div class="status-text" id="statusText">Not connected</div>
</div>
<script>
// Variables
let ws;
let mediaRecorder;
let audioChunks = [];
let isRecording = false;
let audioContext;
let streamProcessor;
let isStreaming = false;
let streamButton;
let isSpeaking = false;
let silenceTimer = null;
let energyWindow = [];
@@ -105,24 +153,20 @@
// DOM elements
const conversationEl = document.getElementById('conversation');
const textInputEl = document.getElementById('textInput');
const speakerSelectEl = document.getElementById('speakerSelect');
const sendTextBtn = document.getElementById('sendText');
const recordAudioBtn = document.getElementById('recordAudio');
const clearContextBtn = document.getElementById('clearContext');
const streamButton = document.getElementById('streamButton');
const clearButton = document.getElementById('clearButton');
const statusDot = document.getElementById('statusDot');
const statusText = document.getElementById('statusText');
// Add streaming button to the input row
// Initialize on page load
window.addEventListener('load', () => {
const inputRow = document.querySelector('.input-row:nth-child(2)');
streamButton = document.createElement('button');
streamButton.id = 'streamAudio';
streamButton.textContent = 'Start Streaming';
streamButton.addEventListener('click', toggleStreaming);
inputRow.appendChild(streamButton);
connectWebSocket();
setupRecording();
setupAudioContext();
// Event listeners
streamButton.addEventListener('click', toggleStreaming);
clearButton.addEventListener('click', clearConversation);
});
// Setup audio context for streaming
@@ -136,8 +180,68 @@
}
}
// Toggle audio streaming
async function toggleStreaming() {
// Connect to WebSocket server
function connectWebSocket() {
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${wsProtocol}//${window.location.hostname}:8000/ws`;
ws = new WebSocket(wsUrl);
ws.onopen = () => {
console.log('WebSocket connected');
statusDot.classList.add('active');
statusText.textContent = 'Connected';
addSystemMessage('Connected to server');
};
ws.onmessage = (event) => {
const response = JSON.parse(event.data);
console.log('Received:', response);
if (response.type === 'audio_response') {
// Play audio response
const audio = new Audio(response.audio);
audio.play();
// Add message to conversation
addAIMessage(response.text || 'AI response', response.audio);
// Reset to speaking state after AI response
if (isStreaming) {
streamButton.textContent = 'Listening...';
streamButton.style.backgroundColor = '#f44336'; // Back to red
streamButton.classList.add('recording');
isSpeaking = false; // Reset speaking state
}
} else if (response.type === 'error') {
addSystemMessage(`Error: ${response.message}`);
} else if (response.type === 'context_updated') {
addSystemMessage(response.message);
} else if (response.type === 'streaming_status') {
addSystemMessage(`Streaming ${response.status}`);
} else if (response.type === 'transcription') {
addUserTranscription(response.text);
}
};
ws.onclose = () => {
console.log('WebSocket disconnected');
statusDot.classList.remove('active');
statusText.textContent = 'Disconnected';
addSystemMessage('Disconnected from server. Reconnecting...');
setTimeout(connectWebSocket, 3000);
};
ws.onerror = (error) => {
console.error('WebSocket error:', error);
statusDot.classList.remove('active');
statusText.textContent = 'Error';
addSystemMessage('Connection error');
};
}
// Toggle streaming
function toggleStreaming() {
if (isStreaming) {
stopStreaming();
} else {
@@ -145,7 +249,7 @@
}
}
// Start audio streaming with silence detection
// Start streaming
async function startStreaming() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
@@ -155,7 +259,7 @@
isSpeaking = false;
energyWindow = [];
streamButton.textContent = 'Speaking...';
streamButton.textContent = 'Listening...';
streamButton.classList.add('recording');
// Create audio processor node
@@ -186,13 +290,13 @@
source.connect(streamProcessor);
streamProcessor.connect(audioContext.destination);
addSystemMessage('Audio streaming started - speak naturally and pause when finished');
addSystemMessage('Listening - speak naturally and pause when finished');
} catch (err) {
console.error('Error starting audio stream:', err);
addSystemMessage(`Streaming error: ${err.message}`);
addSystemMessage(`Microphone error: ${err.message}`);
isStreaming = false;
streamButton.textContent = 'Start Streaming';
streamButton.textContent = 'Start Conversation';
streamButton.classList.remove('recording');
}
}
@@ -228,15 +332,17 @@
silenceTimer = setTimeout(() => {
// Silence persisted long enough
streamButton.textContent = 'Processing...';
streamButton.style.backgroundColor = '#FFA500'; // Orange
streamButton.classList.remove('recording');
streamButton.classList.add('processing');
addSystemMessage('Detected pause in speech, processing response...');
}, CLIENT_SILENCE_DURATION_MS);
}
} else if (!isSpeaking && !isSilent) {
// Transition from silence to speaking
isSpeaking = true;
streamButton.textContent = 'Speaking...';
streamButton.style.backgroundColor = '#f44336'; // Red
streamButton.textContent = 'Listening...';
streamButton.classList.add('recording');
streamButton.classList.remove('processing');
// Clear any pending silence timer
if (silenceTimer) {
@@ -276,7 +382,7 @@
reader.readAsDataURL(wavData);
}
// Stop audio streaming
// Stop streaming
function stopStreaming() {
if (streamProcessor) {
streamProcessor.disconnect();
@@ -293,11 +399,11 @@
isSpeaking = false;
energyWindow = [];
streamButton.textContent = 'Start Streaming';
streamButton.classList.remove('recording');
streamButton.textContent = 'Start Conversation';
streamButton.classList.remove('recording', 'processing');
streamButton.style.backgroundColor = ''; // Reset to default
addSystemMessage('Audio streaming stopped');
addSystemMessage('Conversation paused');
// Send stop streaming signal to server
ws.send(JSON.stringify({
@@ -306,6 +412,18 @@
}));
}
// Clear conversation
function clearConversation() {
// Clear conversation history
ws.send(JSON.stringify({
action: 'clear_context'
}));
// Clear the UI
conversationEl.innerHTML = '';
addSystemMessage('Conversation cleared');
}
// Downsample audio buffer to target sample rate
function downsampleBuffer(buffer, sampleRate, targetSampleRate) {
if (targetSampleRate === sampleRate) {
@@ -376,212 +494,49 @@
}
}
// Connect to WebSocket
function connectWebSocket() {
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${wsProtocol}//${window.location.hostname}:8000/ws`;
// Message display functions
function addUserTranscription(text) {
// Find if there's already a pending user message
let pendingMessage = document.querySelector('.message.user.pending');
ws = new WebSocket(wsUrl);
if (!pendingMessage) {
// Create a new message
pendingMessage = document.createElement('div');
pendingMessage.classList.add('message', 'user', 'pending');
conversationEl.appendChild(pendingMessage);
}
ws.onopen = () => {
console.log('WebSocket connected');
addSystemMessage('Connected to server');
};
ws.onmessage = (event) => {
const response = JSON.parse(event.data);
console.log('Received:', response);
if (response.type === 'audio_response') {
// Play audio response
const audio = new Audio(response.audio);
audio.play();
// Add message to conversation
addAIMessage(response.audio);
// Reset the streaming button if we're still in streaming mode
if (isStreaming) {
streamButton.textContent = 'Speaking...';
streamButton.style.backgroundColor = '#f44336'; // Back to red
isSpeaking = false; // Reset speaking state
}
} else if (response.type === 'error') {
addSystemMessage(`Error: ${response.message}`);
} else if (response.type === 'context_updated') {
addSystemMessage(response.message);
} else if (response.type === 'streaming_status') {
addSystemMessage(`Streaming ${response.status}`);
}
};
ws.onclose = () => {
console.log('WebSocket disconnected');
addSystemMessage('Disconnected from server. Reconnecting...');
setTimeout(connectWebSocket, 3000);
};
ws.onerror = (error) => {
console.error('WebSocket error:', error);
addSystemMessage('Connection error');
};
}
// Add message to conversation
function addUserMessage(text) {
const messageEl = document.createElement('div');
messageEl.classList.add('message', 'user');
messageEl.textContent = text;
conversationEl.appendChild(messageEl);
pendingMessage.textContent = text;
pendingMessage.classList.remove('pending');
conversationEl.scrollTop = conversationEl.scrollHeight;
}
function addAIMessage(audioSrc) {
function addAIMessage(text, audioSrc) {
const messageEl = document.createElement('div');
messageEl.classList.add('message', 'ai');
if (text) {
const textDiv = document.createElement('div');
textDiv.textContent = text;
messageEl.appendChild(textDiv);
}
const audioEl = document.createElement('audio');
audioEl.controls = true;
audioEl.src = audioSrc;
messageEl.appendChild(audioEl);
conversationEl.appendChild(messageEl);
conversationEl.scrollTop = conversationEl.scrollHeight;
}
function addSystemMessage(text) {
const messageEl = document.createElement('div');
messageEl.classList.add('message');
messageEl.classList.add('message', 'system');
messageEl.textContent = text;
conversationEl.appendChild(messageEl);
conversationEl.scrollTop = conversationEl.scrollHeight;
}
// Send text for audio generation
function sendTextForGeneration() {
const text = textInputEl.value.trim();
const speaker = parseInt(speakerSelectEl.value);
if (!text) return;
addUserMessage(text);
textInputEl.value = '';
const request = {
action: 'generate',
text: text,
speaker: speaker
};
ws.send(JSON.stringify(request));
}
// Audio recording functions
async function setupRecording() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.ondataavailable = (event) => {
if (event.data.size > 0) {
audioChunks.push(event.data);
}
};
mediaRecorder.onstop = async () => {
const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
const audioUrl = URL.createObjectURL(audioBlob);
// Add audio to conversation
addUserMessage('Recorded audio:');
const messageEl = document.createElement('div');
messageEl.classList.add('message', 'user');
const audioEl = document.createElement('audio');
audioEl.controls = true;
audioEl.src = audioUrl;
messageEl.appendChild(audioEl);
conversationEl.appendChild(messageEl);
// Convert to base64
const reader = new FileReader();
reader.readAsDataURL(audioBlob);
reader.onloadend = () => {
const base64Audio = reader.result;
const text = textInputEl.value.trim() || "Recorded audio";
const speaker = parseInt(speakerSelectEl.value);
// Send to server
const request = {
action: 'add_to_context',
text: text,
speaker: speaker,
audio: base64Audio
};
ws.send(JSON.stringify(request));
textInputEl.value = '';
};
audioChunks = [];
recordAudioBtn.textContent = 'Record Audio';
recordAudioBtn.classList.remove('recording');
};
console.log('Recording setup completed');
return true;
} catch (err) {
console.error('Error setting up recording:', err);
addSystemMessage(`Microphone access error: ${err.message}`);
return false;
}
}
function toggleRecording() {
if (isRecording) {
mediaRecorder.stop();
isRecording = false;
} else {
if (!mediaRecorder) {
setupRecording().then(success => {
if (success) startRecording();
});
} else {
startRecording();
}
}
}
function startRecording() {
audioChunks = [];
mediaRecorder.start();
isRecording = true;
recordAudioBtn.textContent = 'Stop Recording';
recordAudioBtn.classList.add('recording');
}
// Event listeners
sendTextBtn.addEventListener('click', sendTextForGeneration);
textInputEl.addEventListener('keypress', (e) => {
if (e.key === 'Enter') sendTextForGeneration();
});
recordAudioBtn.addEventListener('click', toggleRecording);
clearContextBtn.addEventListener('click', () => {
ws.send(JSON.stringify({
action: 'clear_context'
}));
});
// Initialize
window.addEventListener('load', () => {
connectWebSocket();
setupRecording();
});
</script>
</body>
</html>
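
The client-side pause detection in the script above rests on a rolling energy window plus a silence timer (energyWindow, CLIENT_SILENCE_DURATION_MS), but only fragments of that logic survive the diff. A minimal sketch of the same RMS-energy idea, written in Python for clarity; the threshold, window size, and 1.5 s hang time are illustrative assumptions, not values taken from this repo:

    import numpy as np
    from collections import deque

    ENERGY_THRESHOLD = 0.01      # assumed RMS level separating speech from silence
    ENERGY_WINDOW_FRAMES = 20    # assumed number of recent frames to smooth over
    SILENCE_DURATION_S = 1.5     # assumed pause length that ends an utterance

    energy_window = deque(maxlen=ENERGY_WINDOW_FRAMES)
    silent_time = 0.0

    def process_frame(frame: np.ndarray, frame_duration_s: float) -> bool:
        """Return True once a long enough pause follows speech, i.e. the utterance looks finished."""
        global silent_time
        rms = float(np.sqrt(np.mean(frame.astype(np.float32) ** 2)))
        energy_window.append(rms)
        smoothed = sum(energy_window) / len(energy_window)
        if smoothed < ENERGY_THRESHOLD:
            silent_time += frame_duration_s
            return silent_time >= SILENCE_DURATION_S
        silent_time = 0.0        # speech resumed, reset the pause timer
        return False
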

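The body of downsampleBuffer is elided by the diff context above. For reference, a generic sketch of the usual block-averaging approach, again in Python with NumPy; the names and structure are illustrative assumptions rather than the repo's actual implementation:

    import numpy as np

    def downsample_buffer(buffer: np.ndarray, sample_rate: int, target_sample_rate: int) -> np.ndarray:
        """Reduce the sample rate of a mono float32 buffer by averaging blocks of input samples."""
        if target_sample_rate == sample_rate:
            return buffer
        if target_sample_rate > sample_rate:
            raise ValueError("target rate must not exceed the source rate")
        ratio = sample_rate / target_sample_rate
        new_length = int(round(len(buffer) / ratio))
        result = np.zeros(new_length, dtype=np.float32)
        for i in range(new_length):
            start = int(round(i * ratio))
            end = min(int(round((i + 1) * ratio)), len(buffer))
            if end > start:
                # each output sample is the mean of the source samples it covers
                result[i] = buffer[start:end].mean()
        return result
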

@@ -5,6 +5,8 @@ import asyncio
import torch
import torchaudio
import numpy as np
import io
import whisperx
from io import BytesIO
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
@@ -13,6 +15,7 @@ from pydantic import BaseModel
from generator import load_csm_1b, Segment
import uvicorn
import time
import gc
from collections import deque
# Select device
@@ -25,6 +28,12 @@ print(f"Using device: {device}")
# Initialize the model
generator = load_csm_1b(device=device)
# Initialize WhisperX for ASR
print("Loading WhisperX model...")
# Use a smaller model for faster response times
asr_model = whisperx.load_model("medium", device, compute_type="float16")
print("WhisperX model loaded!")
app = FastAPI()
# Add CORS middleware to allow cross-origin requests
@@ -93,6 +102,68 @@ async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
return f"data:audio/wav;base64,{audio_base64}"
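Only the final line of encode_audio_data appears in the diff context above. A typical shape for such a helper, assuming a mono tensor and an in-memory WAV encode (a sketch, not necessarily what this file actually does):

    import base64
    from io import BytesIO

    import torch
    import torchaudio

    async def encode_audio_data_sketch(audio_tensor: torch.Tensor, sample_rate: int) -> str:
        """Encode a 1-D audio tensor as a base64 data: URL the browser can hand to new Audio(...)."""
        buffer = BytesIO()
        # torchaudio expects a (channels, samples) tensor on the CPU
        torchaudio.save(buffer, audio_tensor.unsqueeze(0).cpu(), sample_rate, format="wav")
        buffer.seek(0)
        audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
        return f"data:audio/wav;base64,{audio_base64}"
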
async def transcribe_audio(audio_tensor: torch.Tensor) -> str:
"""Transcribe audio using WhisperX"""
try:
# Save the tensor to a temporary file
temp_file = BytesIO()
torchaudio.save(temp_file, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
temp_file.seek(0)
# Create a temporary file on disk (WhisperX requires a file path)
temp_path = "temp_audio.wav"
with open(temp_path, "wb") as f:
f.write(temp_file.read())
# Load and transcribe the audio
audio = whisperx.load_audio(temp_path)
result = asr_model.transcribe(audio, batch_size=16)
# Clean up
os.remove(temp_path)
# Get the transcription text
if result["segments"] and len(result["segments"]) > 0:
# Combine all segments
transcription = " ".join([segment["text"] for segment in result["segments"]])
print(f"Transcription: {transcription}")
return transcription.strip()
else:
return ""
except Exception as e:
print(f"Error in transcription: {str(e)}")
return ""
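transcribe_audio above routes through a temporary file because whisperx.load_audio takes a path. Since load_audio simply returns a 16 kHz float32 NumPy array, the same step can usually be done in memory; a variant sketch, assuming asr_model.transcribe accepts such an array directly (asr_model being the pipeline loaded earlier in this file):

    import numpy as np
    import torch
    import torchaudio

    WHISPER_SAMPLE_RATE = 16_000  # whisperx.load_audio resamples to 16 kHz

    async def transcribe_audio_in_memory(audio_tensor: torch.Tensor, source_sample_rate: int) -> str:
        """Variant sketch: resample in memory and feed the array straight to the ASR model."""
        resampled = torchaudio.functional.resample(
            audio_tensor.cpu(), orig_freq=source_sample_rate, new_freq=WHISPER_SAMPLE_RATE
        )
        audio = resampled.numpy().astype(np.float32)
        result = asr_model.transcribe(audio, batch_size=16)  # same call as in transcribe_audio above
        return " ".join(segment["text"] for segment in result["segments"]).strip()
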
async def generate_response(text: str, conversation_history: List[Segment]) -> str:
"""Generate a contextual response based on the transcribed text"""
# Simple response logic - can be replaced with a more sophisticated LLM in the future
responses = {
"hello": "Hello there! How are you doing today?",
"how are you": "I'm doing well, thanks for asking! How about you?",
"what is your name": "I'm Sesame, your voice assistant. How can I help you?",
"bye": "Goodbye! It was nice chatting with you.",
"thank you": "You're welcome! Is there anything else I can help with?",
"weather": "I don't have real-time weather data, but I hope it's nice where you are!",
"help": "I can chat with you using natural voice. Just speak normally and I'll respond.",
}
text_lower = text.lower()
# Check for matching keywords
for key, response in responses.items():
if key in text_lower:
return response
# Default responses based on text length
if not text:
return "I didn't catch that. Could you please repeat?"
elif len(text) < 10:
return "Thanks for your message. Could you elaborate a bit more?"
else:
return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
@@ -220,30 +291,55 @@ async def websocket_endpoint(websocket: WebSocket):
# User has stopped talking - process the collected audio
full_audio = torch.cat(streaming_buffer, dim=0)
# Process with speech-to-text (you would need to implement this)
# For now, just use a placeholder text
text = f"User audio from speaker {speaker_id}"
# Process with WhisperX speech-to-text
transcribed_text = await transcribe_audio(full_audio)
print(f"Detected end of speech, processing {len(streaming_buffer)} chunks")
# Log the transcription
print(f"Transcribed text: '{transcribed_text}'")
# Add to conversation context
context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))
# Generate response
response_text = "This is a response to what you just said"
audio_tensor = generator.generate(
text=response_text,
speaker=1 if speaker_id == 0 else 0, # Use opposite speaker
context=context_segments,
max_audio_length_ms=10_000,
)
# Convert audio to base64 and send back to client
audio_base64 = await encode_audio_data(audio_tensor)
await websocket.send_json({
"type": "audio_response",
"audio": audio_base64
})
if transcribed_text:
user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
context_segments.append(user_segment)
# Generate a contextual response
response_text = await generate_response(transcribed_text, context_segments)
# Send the transcribed text to client
await websocket.send_json({
"type": "transcription",
"text": transcribed_text
})
# Generate audio for the response
audio_tensor = generator.generate(
text=response_text,
speaker=1 if speaker_id == 0 else 0, # Use opposite speaker
context=context_segments,
max_audio_length_ms=10_000,
)
# Add response to context
ai_segment = Segment(
text=response_text,
speaker=1 if speaker_id == 0 else 0,
audio=audio_tensor
)
context_segments.append(ai_segment)
# Convert audio to base64 and send back to client
audio_base64 = await encode_audio_data(audio_tensor)
await websocket.send_json({
"type": "audio_response",
"text": response_text,
"audio": audio_base64
})
else:
# If transcription failed, send a generic response
await websocket.send_json({
"type": "error",
"message": "Sorry, I couldn't understand what you said. Could you try again?"
})
# Clear buffer and reset silence detection
streaming_buffer = []
@@ -256,8 +352,19 @@ async def websocket_endpoint(websocket: WebSocket):
elif len(streaming_buffer) >= 30: # ~6 seconds of audio at 5 chunks/sec
print("Buffer limit reached, processing audio")
full_audio = torch.cat(streaming_buffer, dim=0)
text = f"Continued speech from speaker {speaker_id}"
context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))
# Process with WhisperX speech-to-text
transcribed_text = await transcribe_audio(full_audio)
if transcribed_text:
context_segments.append(Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio))
# Send the transcribed text to client
await websocket.send_json({
"type": "transcription",
"text": transcribed_text + " (processing continued speech...)"
})
streaming_buffer = []
except Exception as e:
@@ -269,11 +376,21 @@ async def websocket_endpoint(websocket: WebSocket):
elif action == "stop_streaming":
is_streaming = False
if streaming_buffer:
if streaming_buffer and len(streaming_buffer) > 5: # Only process if there's meaningful audio
# Process any remaining audio in the buffer
full_audio = torch.cat(streaming_buffer, dim=0)
text = f"Final streaming audio from speaker {request.get('speaker', 0)}"
context_segments.append(Segment(text=text, speaker=request.get("speaker", 0), audio=full_audio))
# Process with WhisperX speech-to-text
transcribed_text = await transcribe_audio(full_audio)
if transcribed_text:
context_segments.append(Segment(text=transcribed_text, speaker=request.get("speaker", 0), audio=full_audio))
# Send the transcribed text to client
await websocket.send_json({
"type": "transcription",
"text": transcribed_text
})
streaming_buffer = []
await websocket.send_json({
@@ -286,12 +403,15 @@ async def websocket_endpoint(websocket: WebSocket):
print("Client disconnected")
except Exception as e:
print(f"Error: {str(e)}")
await websocket.send_json({
"type": "error",
"message": str(e)
})
try:
await websocket.send_json({
"type": "error",
"message": str(e)
})
except Exception:
pass
manager.disconnect(websocket)
if __name__ == "__main__":
uvicorn.run(app, host="localhost", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)
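
For reference, a minimal sketch of a Python client exercising the /ws endpoint handled above. The websockets package and the localhost URL are assumptions for illustration; the message fields follow the handlers shown in this diff:

    import asyncio
    import json

    import websockets  # assumed client library: pip install websockets

    async def main() -> None:
        # URL assumes the server in this file running locally on port 8000
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            # The new frontend's "Clear Chat" button sends exactly this message
            await ws.send(json.dumps({"action": "clear_context"}))
            reply = json.loads(await ws.recv())
            if reply["type"] == "context_updated":
                print("server says:", reply["message"])
            elif reply["type"] == "error":
                print("error:", reply["message"])

    asyncio.run(main())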