Backend Server Update

.gitignore
@@ -134,3 +134,4 @@ dist
 .yarn/build-state.yml
 .yarn/install-state.gz
 .pnp.*
+Backend/test.py
@@ -10,60 +10,113 @@
             max-width: 800px;
             margin: 0 auto;
             padding: 20px;
+            background-color: #f9f9f9;
         }
         .conversation {
-            border: 1px solid #ccc;
-            border-radius: 8px;
-            padding: 15px;
-            height: 300px;
+            border: 1px solid #ddd;
+            border-radius: 12px;
+            padding: 20px;
+            height: 400px;
             overflow-y: auto;
-            margin-bottom: 15px;
+            margin-bottom: 20px;
+            background-color: white;
+            box-shadow: 0 2px 10px rgba(0,0,0,0.05);
         }
         .message {
-            margin-bottom: 10px;
-            padding: 8px;
-            border-radius: 8px;
+            margin-bottom: 15px;
+            padding: 12px;
+            border-radius: 12px;
+            max-width: 80%;
+            line-height: 1.4;
         }
         .user {
             background-color: #e3f2fd;
             text-align: right;
+            margin-left: auto;
+            border-bottom-right-radius: 4px;
         }
         .ai {
             background-color: #f1f1f1;
+            margin-right: auto;
+            border-bottom-left-radius: 4px;
+        }
+        .system {
+            background-color: #f8f9fa;
+            font-style: italic;
+            text-align: center;
+            font-size: 0.9em;
+            color: #666;
+            padding: 8px;
+            margin: 10px auto;
+            max-width: 90%;
         }
         .controls {
             display: flex;
-            flex-direction: column;
-            gap: 10px;
-        }
-        .input-row {
-            display: flex;
-            gap: 10px;
-        }
-        input[type="text"] {
-            flex-grow: 1;
-            padding: 8px;
-            border-radius: 4px;
-            border: 1px solid #ccc;
+            gap: 15px;
+            justify-content: center;
+            align-items: center;
         }
         button {
-            padding: 8px 16px;
-            border-radius: 4px;
+            padding: 12px 24px;
+            border-radius: 24px;
             border: none;
             background-color: #4CAF50;
             color: white;
             cursor: pointer;
+            font-weight: bold;
+            transition: all 0.2s ease;
+            box-shadow: 0 2px 5px rgba(0,0,0,0.1);
         }
         button:hover {
            background-color: #45a049;
+            box-shadow: 0 4px 8px rgba(0,0,0,0.15);
         }
         .recording {
             background-color: #f44336;
+            animation: pulse 1.5s infinite;
+        }
+        .processing {
+            background-color: #FFA500;
         }
         select {
-            padding: 8px;
-            border-radius: 4px;
-            border: 1px solid #ccc;
+            padding: 10px;
+            border-radius: 24px;
+            border: 1px solid #ddd;
+            background-color: white;
+        }
+        .transcript {
+            font-style: italic;
+            color: #666;
+            margin-top: 5px;
+        }
+        @keyframes pulse {
+            0% { opacity: 1; }
+            50% { opacity: 0.7; }
+            100% { opacity: 1; }
+        }
+        .status-indicator {
+            display: flex;
+            align-items: center;
+            justify-content: center;
+            margin-top: 10px;
+            gap: 5px;
+        }
+        .status-dot {
+            width: 10px;
+            height: 10px;
+            border-radius: 50%;
+            background-color: #ccc;
+        }
+        .status-dot.active {
+            background-color: #4CAF50;
+        }
+        .status-text {
+            font-size: 0.9em;
+            color: #666;
+        }
+        audio {
+            width: 100%;
+            margin-top: 5px;
         }
     </style>
 </head>
@@ -72,30 +125,25 @@
     <div class="conversation" id="conversation"></div>

     <div class="controls">
-        <div class="input-row">
-            <input type="text" id="textInput" placeholder="Type your message...">
         <select id="speakerSelect">
             <option value="0">Speaker 0</option>
             <option value="1">Speaker 1</option>
         </select>
-        <button id="sendText">Send</button>
+        <button id="streamButton">Start Conversation</button>
+        <button id="clearButton">Clear Chat</button>
     </div>

-    <div class="input-row">
-        <button id="recordAudio">Record Audio</button>
-        <button id="clearContext">Clear Context</button>
-    </div>
+    <div class="status-indicator">
+        <div class="status-dot" id="statusDot"></div>
+        <div class="status-text" id="statusText">Not connected</div>
     </div>

     <script>
+        // Variables
         let ws;
-        let mediaRecorder;
-        let audioChunks = [];
-        let isRecording = false;
         let audioContext;
         let streamProcessor;
         let isStreaming = false;
-        let streamButton;
         let isSpeaking = false;
         let silenceTimer = null;
         let energyWindow = [];
@@ -105,24 +153,20 @@

         // DOM elements
         const conversationEl = document.getElementById('conversation');
-        const textInputEl = document.getElementById('textInput');
         const speakerSelectEl = document.getElementById('speakerSelect');
-        const sendTextBtn = document.getElementById('sendText');
-        const recordAudioBtn = document.getElementById('recordAudio');
-        const clearContextBtn = document.getElementById('clearContext');
+        const streamButton = document.getElementById('streamButton');
+        const clearButton = document.getElementById('clearButton');
+        const statusDot = document.getElementById('statusDot');
+        const statusText = document.getElementById('statusText');

-        // Add streaming button to the input row
+        // Initialize on page load
         window.addEventListener('load', () => {
-            const inputRow = document.querySelector('.input-row:nth-child(2)');
-            streamButton = document.createElement('button');
-            streamButton.id = 'streamAudio';
-            streamButton.textContent = 'Start Streaming';
-            streamButton.addEventListener('click', toggleStreaming);
-            inputRow.appendChild(streamButton);
-
             connectWebSocket();
-            setupRecording();
             setupAudioContext();

+            // Event listeners
+            streamButton.addEventListener('click', toggleStreaming);
+            clearButton.addEventListener('click', clearConversation);
         });

         // Setup audio context for streaming
@@ -136,8 +180,68 @@
             }
         }

-        // Toggle audio streaming
-        async function toggleStreaming() {
+        // Connect to WebSocket server
+        function connectWebSocket() {
+            const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
+            const wsUrl = `${wsProtocol}//${window.location.hostname}:8000/ws`;
+
+            ws = new WebSocket(wsUrl);
+
+            ws.onopen = () => {
+                console.log('WebSocket connected');
+                statusDot.classList.add('active');
+                statusText.textContent = 'Connected';
+                addSystemMessage('Connected to server');
+            };
+
+            ws.onmessage = (event) => {
+                const response = JSON.parse(event.data);
+                console.log('Received:', response);
+
+                if (response.type === 'audio_response') {
+                    // Play audio response
+                    const audio = new Audio(response.audio);
+                    audio.play();
+
+                    // Add message to conversation
+                    addAIMessage(response.text || 'AI response', response.audio);
+
+                    // Reset to speaking state after AI response
+                    if (isStreaming) {
+                        streamButton.textContent = 'Listening...';
+                        streamButton.style.backgroundColor = '#f44336'; // Back to red
+                        streamButton.classList.add('recording');
+                        isSpeaking = false; // Reset speaking state
+                    }
+                } else if (response.type === 'error') {
+                    addSystemMessage(`Error: ${response.message}`);
+                } else if (response.type === 'context_updated') {
+                    addSystemMessage(response.message);
+                } else if (response.type === 'streaming_status') {
+                    addSystemMessage(`Streaming ${response.status}`);
+                } else if (response.type === 'transcription') {
+                    addUserTranscription(response.text);
+                }
+            };
+
+            ws.onclose = () => {
+                console.log('WebSocket disconnected');
+                statusDot.classList.remove('active');
+                statusText.textContent = 'Disconnected';
+                addSystemMessage('Disconnected from server. Reconnecting...');
+                setTimeout(connectWebSocket, 3000);
+            };
+
+            ws.onerror = (error) => {
+                console.error('WebSocket error:', error);
+                statusDot.classList.remove('active');
+                statusText.textContent = 'Error';
+                addSystemMessage('Connection error');
+            };
+        }
+
+        // Toggle streaming
+        function toggleStreaming() {
             if (isStreaming) {
                 stopStreaming();
             } else {
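
The handlers above define the whole client-server protocol: the client sends JSON objects carrying an "action" field, and the server replies with a "type" of audio_response, transcription, context_updated, streaming_status, or error. For reference, a minimal sketch of a standalone client for that protocol, assuming the backend further down in this diff is running on localhost:8000 and the third-party `websockets` package is installed:

    import asyncio
    import json

    import websockets

    async def main():
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            # Ask the server to drop its conversation context...
            await ws.send(json.dumps({"action": "clear_context"}))
            # ...and print whatever comes back (the page expects a
            # context_updated message for this action).
            reply = json.loads(await ws.recv())
            print(reply.get("type"), reply.get("message"))

    asyncio.run(main())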
@@ -145,7 +249,7 @@
             }
         }

-        // Start audio streaming with silence detection
+        // Start streaming
         async function startStreaming() {
             try {
                 const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
@@ -155,7 +259,7 @@
                 isSpeaking = false;
                 energyWindow = [];

-                streamButton.textContent = 'Speaking...';
+                streamButton.textContent = 'Listening...';
                 streamButton.classList.add('recording');

                 // Create audio processor node
@@ -186,13 +290,13 @@
                 source.connect(streamProcessor);
                 streamProcessor.connect(audioContext.destination);

-                addSystemMessage('Audio streaming started - speak naturally and pause when finished');
+                addSystemMessage('Listening - speak naturally and pause when finished');

             } catch (err) {
                 console.error('Error starting audio stream:', err);
-                addSystemMessage(`Streaming error: ${err.message}`);
+                addSystemMessage(`Microphone error: ${err.message}`);
                 isStreaming = false;
-                streamButton.textContent = 'Start Streaming';
+                streamButton.textContent = 'Start Conversation';
                 streamButton.classList.remove('recording');
             }
         }
@@ -228,15 +332,17 @@
                     silenceTimer = setTimeout(() => {
                         // Silence persisted long enough
                         streamButton.textContent = 'Processing...';
-                        streamButton.style.backgroundColor = '#FFA500'; // Orange
+                        streamButton.classList.remove('recording');
+                        streamButton.classList.add('processing');
                         addSystemMessage('Detected pause in speech, processing response...');
                     }, CLIENT_SILENCE_DURATION_MS);
                 }
             } else if (!isSpeaking && !isSilent) {
                 // Transition from silence to speaking
                 isSpeaking = true;
-                streamButton.textContent = 'Speaking...';
-                streamButton.style.backgroundColor = '#f44336'; // Red
+                streamButton.textContent = 'Listening...';
+                streamButton.classList.add('recording');
+                streamButton.classList.remove('processing');

                 // Clear any pending silence timer
                 if (silenceTimer) {
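
The hunk above only shows the UI transitions; the decision itself comes from a rolling window of chunk energies (energyWindow) compared against a threshold, with CLIENT_SILENCE_DURATION_MS debouncing the speech-to-silence edge. A sketch of that windowed-RMS test in Python, with illustrative constants (the real window size and threshold are client-side values not shown in this diff):

    import numpy as np

    ENERGY_WINDOW_SIZE = 10    # recent chunks to smooth over (assumed value)
    SILENCE_THRESHOLD = 0.01   # mean RMS below this counts as silence (assumed value)

    energy_window = []

    def is_silent(chunk: np.ndarray) -> bool:
        """True when the smoothed RMS energy of recent chunks is low."""
        energy_window.append(float(np.sqrt(np.mean(chunk ** 2))))
        if len(energy_window) > ENERGY_WINDOW_SIZE:
            energy_window.pop(0)
        return sum(energy_window) / len(energy_window) < SILENCE_THRESHOLD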
@@ -276,7 +382,7 @@
             reader.readAsDataURL(wavData);
         }

-        // Stop audio streaming
+        // Stop streaming
         function stopStreaming() {
             if (streamProcessor) {
                 streamProcessor.disconnect();
@@ -293,11 +399,11 @@
             isSpeaking = false;
             energyWindow = [];

-            streamButton.textContent = 'Start Streaming';
-            streamButton.classList.remove('recording');
+            streamButton.textContent = 'Start Conversation';
+            streamButton.classList.remove('recording', 'processing');
             streamButton.style.backgroundColor = ''; // Reset to default

-            addSystemMessage('Audio streaming stopped');
+            addSystemMessage('Conversation paused');

             // Send stop streaming signal to server
             ws.send(JSON.stringify({
@@ -306,6 +412,18 @@
             }));
         }

+        // Clear conversation
+        function clearConversation() {
+            // Clear conversation history
+            ws.send(JSON.stringify({
+                action: 'clear_context'
+            }));
+
+            // Clear the UI
+            conversationEl.innerHTML = '';
+            addSystemMessage('Conversation cleared');
+        }
+
         // Downsample audio buffer to target sample rate
         function downsampleBuffer(buffer, sampleRate, targetSampleRate) {
             if (targetSampleRate === sampleRate) {
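
downsampleBuffer appears only as context here. For reference, a sketch of the usual block-averaging approach behind that signature, transplanted to Python; this illustrates the technique, not the page's exact implementation:

    import numpy as np

    def downsample_buffer(buffer: np.ndarray, sample_rate: int, target_sample_rate: int) -> np.ndarray:
        """Reduce the sample rate by averaging each block of source samples."""
        if target_sample_rate == sample_rate:
            return buffer
        ratio = sample_rate / target_sample_rate
        new_length = int(len(buffer) / ratio)
        result = np.empty(new_length, dtype=np.float32)
        for i in range(new_length):
            start = min(int(i * ratio), len(buffer) - 1)
            end = min(int((i + 1) * ratio), len(buffer))
            block = buffer[start:end] if end > start else buffer[start:start + 1]
            result[i] = block.mean()
        return result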
@@ -376,212 +494,49 @@
             }
         }

-        // Connect to WebSocket
-        function connectWebSocket() {
-            const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
-            const wsUrl = `${wsProtocol}//${window.location.hostname}:8000/ws`;
+        // Message display functions
+        function addUserTranscription(text) {
+            // Find if there's already a pending user message
+            let pendingMessage = document.querySelector('.message.user.pending');

-            ws = new WebSocket(wsUrl);
-
-            ws.onopen = () => {
-                console.log('WebSocket connected');
-                addSystemMessage('Connected to server');
-            };
-
-            ws.onmessage = (event) => {
-                const response = JSON.parse(event.data);
-                console.log('Received:', response);
-
-                if (response.type === 'audio_response') {
-                    // Play audio response
-                    const audio = new Audio(response.audio);
-                    audio.play();
-
-                    // Add message to conversation
-                    addAIMessage(response.audio);
-
-                    // Reset the streaming button if we're still in streaming mode
-                    if (isStreaming) {
-                        streamButton.textContent = 'Speaking...';
-                        streamButton.style.backgroundColor = '#f44336'; // Back to red
-                        isSpeaking = false; // Reset speaking state
-                    }
-                } else if (response.type === 'error') {
-                    addSystemMessage(`Error: ${response.message}`);
-                } else if (response.type === 'context_updated') {
-                    addSystemMessage(response.message);
-                } else if (response.type === 'streaming_status') {
-                    addSystemMessage(`Streaming ${response.status}`);
-                }
-            };
-
-            ws.onclose = () => {
-                console.log('WebSocket disconnected');
-                addSystemMessage('Disconnected from server. Reconnecting...');
-                setTimeout(connectWebSocket, 3000);
-            };
-
-            ws.onerror = (error) => {
-                console.error('WebSocket error:', error);
-                addSystemMessage('Connection error');
-            };
+            if (!pendingMessage) {
+                // Create a new message
+                pendingMessage = document.createElement('div');
+                pendingMessage.classList.add('message', 'user', 'pending');
+                conversationEl.appendChild(pendingMessage);
             }

-        // Add message to conversation
-        function addUserMessage(text) {
-            const messageEl = document.createElement('div');
-            messageEl.classList.add('message', 'user');
-            messageEl.textContent = text;
-            conversationEl.appendChild(messageEl);
+            pendingMessage.textContent = text;
+            pendingMessage.classList.remove('pending');
             conversationEl.scrollTop = conversationEl.scrollHeight;
         }

-        function addAIMessage(audioSrc) {
+        function addAIMessage(text, audioSrc) {
             const messageEl = document.createElement('div');
             messageEl.classList.add('message', 'ai');

+            if (text) {
+                const textDiv = document.createElement('div');
+                textDiv.textContent = text;
+                messageEl.appendChild(textDiv);
+            }
+
             const audioEl = document.createElement('audio');
             audioEl.controls = true;
             audioEl.src = audioSrc;

             messageEl.appendChild(audioEl);

             conversationEl.appendChild(messageEl);
             conversationEl.scrollTop = conversationEl.scrollHeight;
         }

         function addSystemMessage(text) {
             const messageEl = document.createElement('div');
-            messageEl.classList.add('message');
+            messageEl.classList.add('message', 'system');
             messageEl.textContent = text;
             conversationEl.appendChild(messageEl);
             conversationEl.scrollTop = conversationEl.scrollHeight;
         }

-        // Send text for audio generation
-        function sendTextForGeneration() {
-            const text = textInputEl.value.trim();
-            const speaker = parseInt(speakerSelectEl.value);
-
-            if (!text) return;
-
-            addUserMessage(text);
-            textInputEl.value = '';
-
-            const request = {
-                action: 'generate',
-                text: text,
-                speaker: speaker
-            };
-
-            ws.send(JSON.stringify(request));
-        }
-
-        // Audio recording functions
-        async function setupRecording() {
-            try {
-                const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
-
-                mediaRecorder = new MediaRecorder(stream);
-
-                mediaRecorder.ondataavailable = (event) => {
-                    if (event.data.size > 0) {
-                        audioChunks.push(event.data);
-                    }
-                };
-
-                mediaRecorder.onstop = async () => {
-                    const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
-                    const audioUrl = URL.createObjectURL(audioBlob);
-
-                    // Add audio to conversation
-                    addUserMessage('Recorded audio:');
-                    const messageEl = document.createElement('div');
-                    messageEl.classList.add('message', 'user');
-
-                    const audioEl = document.createElement('audio');
-                    audioEl.controls = true;
-                    audioEl.src = audioUrl;
-
-                    messageEl.appendChild(audioEl);
-                    conversationEl.appendChild(messageEl);
-
-                    // Convert to base64
-                    const reader = new FileReader();
-                    reader.readAsDataURL(audioBlob);
-                    reader.onloadend = () => {
-                        const base64Audio = reader.result;
-                        const text = textInputEl.value.trim() || "Recorded audio";
-                        const speaker = parseInt(speakerSelectEl.value);
-
-                        // Send to server
-                        const request = {
-                            action: 'add_to_context',
-                            text: text,
-                            speaker: speaker,
-                            audio: base64Audio
-                        };
-
-                        ws.send(JSON.stringify(request));
-                        textInputEl.value = '';
-                    };
-
-                    audioChunks = [];
-                    recordAudioBtn.textContent = 'Record Audio';
-                    recordAudioBtn.classList.remove('recording');
-                };
-
-                console.log('Recording setup completed');
-                return true;
-            } catch (err) {
-                console.error('Error setting up recording:', err);
-                addSystemMessage(`Microphone access error: ${err.message}`);
-                return false;
-            }
-        }
-
-        function toggleRecording() {
-            if (isRecording) {
-                mediaRecorder.stop();
-                isRecording = false;
-            } else {
-                if (!mediaRecorder) {
-                    setupRecording().then(success => {
-                        if (success) startRecording();
-                    });
-                } else {
-                    startRecording();
-                }
-            }
-        }
-
-        function startRecording() {
-            audioChunks = [];
-            mediaRecorder.start();
-            isRecording = true;
-            recordAudioBtn.textContent = 'Stop Recording';
-            recordAudioBtn.classList.add('recording');
-        }
-
-        // Event listeners
-        sendTextBtn.addEventListener('click', sendTextForGeneration);
-
-        textInputEl.addEventListener('keypress', (e) => {
-            if (e.key === 'Enter') sendTextForGeneration();
-        });
-
-        recordAudioBtn.addEventListener('click', toggleRecording);
-
-        clearContextBtn.addEventListener('click', () => {
-            ws.send(JSON.stringify({
-                action: 'clear_context'
-            }));
-        });
-
-        // Initialize
-        window.addEventListener('load', () => {
-            connectWebSocket();
-            setupRecording();
-        });
     </script>
 </body>
 </html>
@@ -5,6 +5,8 @@ import asyncio
 import torch
 import torchaudio
 import numpy as np
+import io
+import whisperx
 from io import BytesIO
 from typing import List, Dict, Any, Optional
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|||||||
from generator import load_csm_1b, Segment
|
from generator import load_csm_1b, Segment
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import time
|
import time
|
||||||
|
import gc
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
# Select device
|
# Select device
|
||||||
@@ -25,6 +28,12 @@ print(f"Using device: {device}")
 # Initialize the model
 generator = load_csm_1b(device=device)
+
+# Initialize WhisperX for ASR
+print("Loading WhisperX model...")
+# Use a smaller model for faster response times
+asr_model = whisperx.load_model("medium", device, compute_type="float16")
+print("WhisperX model loaded!")

 app = FastAPI()

 # Add CORS middleware to allow cross-origin requests
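
One caveat with the load above: compute_type="float16" assumes a CUDA device, and faster-whisper (which backs WhisperX) typically rejects float16 on CPU. A hedged variant that degrades gracefully on CPU-only machines:

    import torch
    import whisperx

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # int8 is a common CPU fallback for faster-whisper/WhisperX
    compute_type = "float16" if device == "cuda" else "int8"
    asr_model = whisperx.load_model("medium", device, compute_type=compute_type)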
@@ -93,6 +102,68 @@ async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
     return f"data:audio/wav;base64,{audio_base64}"

+
+async def transcribe_audio(audio_tensor: torch.Tensor) -> str:
+    """Transcribe audio using WhisperX"""
+    try:
+        # Save the tensor to a temporary file
+        temp_file = BytesIO()
+        torchaudio.save(temp_file, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
+        temp_file.seek(0)
+
+        # Create a temporary file on disk (WhisperX requires a file path)
+        temp_path = "temp_audio.wav"
+        with open(temp_path, "wb") as f:
+            f.write(temp_file.read())
+
+        # Load and transcribe the audio
+        audio = whisperx.load_audio(temp_path)
+        result = asr_model.transcribe(audio, batch_size=16)
+
+        # Clean up
+        os.remove(temp_path)
+
+        # Get the transcription text
+        if result["segments"] and len(result["segments"]) > 0:
+            # Combine all segments
+            transcription = " ".join([segment["text"] for segment in result["segments"]])
+            print(f"Transcription: {transcription}")
+            return transcription.strip()
+        else:
+            return ""
+    except Exception as e:
+        print(f"Error in transcription: {str(e)}")
+        return ""
+
+
+async def generate_response(text: str, conversation_history: List[Segment]) -> str:
+    """Generate a contextual response based on the transcribed text"""
+    # Simple response logic - can be replaced with a more sophisticated LLM in the future
+    responses = {
+        "hello": "Hello there! How are you doing today?",
+        "how are you": "I'm doing well, thanks for asking! How about you?",
+        "what is your name": "I'm Sesame, your voice assistant. How can I help you?",
+        "bye": "Goodbye! It was nice chatting with you.",
+        "thank you": "You're welcome! Is there anything else I can help with?",
+        "weather": "I don't have real-time weather data, but I hope it's nice where you are!",
+        "help": "I can chat with you using natural voice. Just speak normally and I'll respond.",
+    }
+
+    text_lower = text.lower()
+
+    # Check for matching keywords
+    for key, response in responses.items():
+        if key in text_lower:
+            return response
+
+    # Default responses based on text length
+    if not text:
+        return "I didn't catch that. Could you please repeat?"
+    elif len(text) < 10:
+        return "Thanks for your message. Could you elaborate a bit more?"
+    else:
+        return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"
+
+
 @app.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket):
     await manager.connect(websocket)
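
Two things worth noting about transcribe_audio: it calls os.remove although no "import os" appears in this diff (so the module presumably imports it elsewhere), and it writes a fixed temp_audio.wav, which two concurrent connections would overwrite. A sketch of a variant that uses a unique temporary file per call, with the same WhisperX calls as above and asr_model/generator as defined earlier:

    import os
    import tempfile

    import torchaudio
    import whisperx

    async def transcribe_audio_unique(audio_tensor) -> str:
        """Like transcribe_audio above, but safe under concurrent callers."""
        fd, temp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
            audio = whisperx.load_audio(temp_path)
            result = asr_model.transcribe(audio, batch_size=16)
            return " ".join(seg["text"] for seg in result["segments"]).strip()
        finally:
            os.remove(temp_path)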
@@ -220,17 +291,27 @@ async def websocket_endpoint(websocket: WebSocket):
                     # User has stopped talking - process the collected audio
                     full_audio = torch.cat(streaming_buffer, dim=0)

-                    # Process with speech-to-text (you would need to implement this)
-                    # For now, just use a placeholder text
-                    text = f"User audio from speaker {speaker_id}"
+                    # Process with WhisperX speech-to-text
+                    transcribed_text = await transcribe_audio(full_audio)

-                    print(f"Detected end of speech, processing {len(streaming_buffer)} chunks")
+                    # Log the transcription
+                    print(f"Transcribed text: '{transcribed_text}'")

                     # Add to conversation context
-                    context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))
+                    if transcribed_text:
+                        user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
+                        context_segments.append(user_segment)

-                    # Generate response
-                    response_text = "This is a response to what you just said"
+                        # Generate a contextual response
+                        response_text = await generate_response(transcribed_text, context_segments)
+
+                        # Send the transcribed text to client
+                        await websocket.send_json({
+                            "type": "transcription",
+                            "text": transcribed_text
+                        })
+
+                        # Generate audio for the response
                         audio_tensor = generator.generate(
                             text=response_text,
                             speaker=1 if speaker_id == 0 else 0,  # Use opposite speaker
@@ -238,12 +319,27 @@ async def websocket_endpoint(websocket: WebSocket):
                             max_audio_length_ms=10_000,
                         )

+                        # Add response to context
+                        ai_segment = Segment(
+                            text=response_text,
+                            speaker=1 if speaker_id == 0 else 0,
+                            audio=audio_tensor
+                        )
+                        context_segments.append(ai_segment)
+
                         # Convert audio to base64 and send back to client
                         audio_base64 = await encode_audio_data(audio_tensor)
                         await websocket.send_json({
                             "type": "audio_response",
+                            "text": response_text,
                             "audio": audio_base64
                         })
+                    else:
+                        # If transcription failed, send a generic response
+                        await websocket.send_json({
+                            "type": "error",
+                            "message": "Sorry, I couldn't understand what you said. Could you try again?"
+                        })

                     # Clear buffer and reset silence detection
                     streaming_buffer = []
@@ -256,8 +352,19 @@ async def websocket_endpoint(websocket: WebSocket):
                 elif len(streaming_buffer) >= 30:  # ~6 seconds of audio at 5 chunks/sec
                     print("Buffer limit reached, processing audio")
                     full_audio = torch.cat(streaming_buffer, dim=0)
-                    text = f"Continued speech from speaker {speaker_id}"
-                    context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))
+                    # Process with WhisperX speech-to-text
+                    transcribed_text = await transcribe_audio(full_audio)
+
+                    if transcribed_text:
+                        context_segments.append(Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio))
+
+                        # Send the transcribed text to client
+                        await websocket.send_json({
+                            "type": "transcription",
+                            "text": transcribed_text + " (processing continued speech...)"
+                        })
+
                     streaming_buffer = []

             except Exception as e:
@@ -269,11 +376,21 @@ async def websocket_endpoint(websocket: WebSocket):

             elif action == "stop_streaming":
                 is_streaming = False
-                if streaming_buffer:
+                if streaming_buffer and len(streaming_buffer) > 5:  # Only process if there's meaningful audio
                     # Process any remaining audio in the buffer
                     full_audio = torch.cat(streaming_buffer, dim=0)
-                    text = f"Final streaming audio from speaker {request.get('speaker', 0)}"
-                    context_segments.append(Segment(text=text, speaker=request.get("speaker", 0), audio=full_audio))
+                    # Process with WhisperX speech-to-text
+                    transcribed_text = await transcribe_audio(full_audio)
+
+                    if transcribed_text:
+                        context_segments.append(Segment(text=transcribed_text, speaker=request.get("speaker", 0), audio=full_audio))
+
+                        # Send the transcribed text to client
+                        await websocket.send_json({
+                            "type": "transcription",
+                            "text": transcribed_text
+                        })
+
                 streaming_buffer = []
                 await websocket.send_json({
@@ -286,12 +403,15 @@ async def websocket_endpoint(websocket: WebSocket):
         print("Client disconnected")
     except Exception as e:
         print(f"Error: {str(e)}")
+        try:
             await websocket.send_json({
                 "type": "error",
                 "message": str(e)
             })
+        except:
+            pass
         manager.disconnect(websocket)


 if __name__ == "__main__":
-    uvicorn.run(app, host="localhost", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=8000)
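
The final change switches the bind address from localhost to 0.0.0.0, so the websocket becomes reachable from other machines on the network; this matches the client, which builds its URL from window.location.hostname rather than hardcoding localhost. The equivalent CLI launch would be `uvicorn server:app --host 0.0.0.0 --port 8000`, assuming the module is saved as server.py (a filename this diff does not show).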