Demo Update 21
@@ -1,86 +1,225 @@
 <!DOCTYPE html>
 <html lang="en">
 <head>
     <meta charset="UTF-8">
     <meta name="viewport" content="width=device-width, initial-scale=1.0">
-    <title>Audio Conversation Bot</title>
+    <title>Voice Assistant - CSM & Whisper</title>
     <script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
     <style>
         body {
-            font-family: Arial, sans-serif;
+            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
             max-width: 800px;
             margin: 0 auto;
             padding: 20px;
+            background-color: #f5f7fa;
+            color: #333;
         }

+        h1 {
+            color: #2c3e50;
+            text-align: center;
+            margin-bottom: 30px;
+        }
+
         #conversation {
             height: 400px;
-            border: 1px solid #ccc;
-            padding: 15px;
+            border: 1px solid #ddd;
+            border-radius: 10px;
+            padding: 20px;
             margin-bottom: 20px;
             overflow-y: auto;
+            background-color: white;
+            box-shadow: 0 2px 5px rgba(0,0,0,0.1);
         }

+        .message-container {
+            display: flex;
+            flex-direction: column;
+            margin-bottom: 15px;
+        }
+
+        .user-message-container {
+            align-items: flex-end;
+        }
+
+        .bot-message-container {
+            align-items: flex-start;
+        }
+
+        .message {
+            max-width: 80%;
+            padding: 12px;
+            border-radius: 18px;
+            position: relative;
+            word-break: break-word;
+        }
+
         .user-message {
-            background-color: #e1f5fe;
-            padding: 10px;
-            border-radius: 8px;
-            margin-bottom: 10px;
-            align-self: flex-end;
+            background-color: #dcf8c6;
+            color: #000;
+            border-bottom-right-radius: 4px;
         }

         .bot-message {
-            background-color: #f1f1f1;
-            padding: 10px;
-            border-radius: 8px;
-            margin-bottom: 10px;
+            background-color: #f1f0f0;
+            color: #000;
+            border-bottom-left-radius: 4px;
         }

+        .message-label {
+            font-size: 0.8em;
+            margin-bottom: 4px;
+            color: #657786;
+        }
+
         #controls {
             display: flex;
             gap: 10px;
+            justify-content: center;
+            margin-bottom: 15px;
         }

         button {
-            padding: 10px 20px;
+            padding: 12px 24px;
             font-size: 16px;
             cursor: pointer;
+            border-radius: 50px;
+            border: none;
+            outline: none;
+            transition: all 0.3s ease;
         }

         #recordButton {
             background-color: #4CAF50;
             color: white;
-            border: none;
-            border-radius: 4px;
+            width: 200px;
+            box-shadow: 0 4px 8px rgba(76, 175, 80, 0.3);
         }

+        #recordButton:hover {
+            background-color: #45a049;
+            transform: translateY(-2px);
+        }
+
         #recordButton.recording {
             background-color: #f44336;
+            animation: pulse 1.5s infinite;
+            box-shadow: 0 4px 8px rgba(244, 67, 54, 0.3);
         }

+        @keyframes pulse {
+            0% {
+                transform: scale(1);
+            }
+            50% {
+                transform: scale(1.05);
+            }
+            100% {
+                transform: scale(1);
+            }
+        }
+
         #status {
-            margin-top: 10px;
+            text-align: center;
+            margin-top: 15px;
             font-style: italic;
+            color: #657786;
+        }
+
+        .audio-wave {
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            height: 40px;
+            gap: 3px;
+        }
+
+        .audio-wave span {
+            display: block;
+            width: 3px;
+            height: 100%;
+            background-color: #4CAF50;
+            animation: wave 1.5s infinite ease-in-out;
+            border-radius: 6px;
+        }
+
+        .audio-wave span:nth-child(2) {
+            animation-delay: 0.2s;
+        }
+        .audio-wave span:nth-child(3) {
+            animation-delay: 0.4s;
+        }
+        .audio-wave span:nth-child(4) {
+            animation-delay: 0.6s;
+        }
+        .audio-wave span:nth-child(5) {
+            animation-delay: 0.8s;
+        }
+
+        @keyframes wave {
+            0%, 100% {
+                height: 8px;
+            }
+            50% {
+                height: 30px;
+            }
+        }
+
+        .hidden {
+            display: none;
+        }
+
+        .transcription-info {
+            font-size: 0.8em;
+            color: #888;
+            margin-top: 4px;
+            text-align: right;
         }
     </style>
 </head>
 <body>
-    <h1>Audio Conversation Bot</h1>
+    <h1>Voice Assistant with CSM & Whisper</h1>
     <div id="conversation"></div>

     <div id="controls">
         <button id="recordButton">Hold to Speak</button>
     </div>
-    <div id="status">Not connected</div>
+    <div id="audioWave" class="audio-wave hidden">
+        <span></span>
+        <span></span>
+        <span></span>
+        <span></span>
+        <span></span>
+    </div>
+
+    <div id="status">Connecting to server...</div>

     <script>
         const socket = io();
         const recordButton = document.getElementById('recordButton');
         const conversation = document.getElementById('conversation');
         const status = document.getElementById('status');
+        const audioWave = document.getElementById('audioWave');

         let mediaRecorder;
         let audioChunks = [];
         let isRecording = false;
+        let audioSendInterval;
+        let sessionActive = false;

-        // Initialize audio context and analyzer
+        // Initialize audio context
         const audioContext = new (window.AudioContext || window.webkitAudioContext)();

         // Connect to server
         socket.on('connect', () => {
             status.textContent = 'Connected to server';
+            sessionActive = true;
+        });
+
+        socket.on('disconnect', () => {
+            status.textContent = 'Disconnected from server';
+            sessionActive = false;
         });

         socket.on('ready', (data) => {
@@ -90,28 +229,59 @@

         socket.on('transcription', (data) => {
             addMessage('user', data.text);
+            status.textContent = 'Assistant is thinking...';
         });

         socket.on('audio_response', (data) => {
             // Play audio
+            status.textContent = 'Playing response...';
             const audio = new Audio('data:audio/wav;base64,' + data.audio);
-            audio.play();
+            audio.onended = () => {
+                status.textContent = 'Ready to record';
+            };
+
+            audio.onerror = () => {
+                status.textContent = 'Error playing audio';
+                console.error('Error playing audio response');
+            };
+
+            audio.play().catch(err => {
+                status.textContent = 'Error playing audio: ' + err.message;
+                console.error('Error playing audio:', err);
+            });

             // Display text
             addMessage('bot', data.text);
         });

         socket.on('error', (data) => {
-            status.textContent = data.message;
-            console.error(data.message);
+            status.textContent = 'Error: ' + data.message;
+            console.error('Server error:', data.message);
         });

         function setupAudioRecording() {
+            // Check if browser supports required APIs
+            if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
+                status.textContent = 'Your browser does not support audio recording';
+                return;
+            }
+
             // Get user media
             navigator.mediaDevices.getUserMedia({ audio: true })
                 .then(stream => {
-                    // Setup recording
-                    mediaRecorder = new MediaRecorder(stream);
+                    // Setup recording with better audio quality
+                    const options = {
+                        mimeType: 'audio/webm',
+                        audioBitsPerSecond: 128000
+                    };
+
+                    try {
+                        mediaRecorder = new MediaRecorder(stream, options);
+                    } catch (e) {
+                        // Fallback if the specified options aren't supported
+                        mediaRecorder = new MediaRecorder(stream);
+                    }
+
                     mediaRecorder.ondataavailable = event => {
                         if (event.data.size > 0) {
@@ -120,36 +290,28 @@
                     };

                     mediaRecorder.onstop = () => {
-                        const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
-                        audioChunks = [];
-
-                        // Convert to Float32Array for sending
-                        const fileReader = new FileReader();
-                        fileReader.onloadend = () => {
-                            const arrayBuffer = fileReader.result;
-                            const floatArray = new Float32Array(arrayBuffer);
-
-                            // Convert to base64
-                            const base64String = arrayBufferToBase64(floatArray.buffer);
-                            socket.emit('audio_chunk', { audio: base64String });
-                        };
-                        fileReader.readAsArrayBuffer(audioBlob);
-
-                        socket.emit('stop_speaking');
-                        isRecording = false;
+                        processRecording();
                     };

-                    // Setup audio analyzer for chunking and VAD
+                    // Create audio analyzer for visualization
                     const source = audioContext.createMediaStreamSource(stream);
                     const analyzer = audioContext.createAnalyser();
                     analyzer.fftSize = 2048;
                     source.connect(analyzer);

-                    // Setup button handlers
+                    // Setup button handlers with better touch handling
                     recordButton.addEventListener('mousedown', startRecording);
-                    recordButton.addEventListener('touchstart', startRecording);
+                    recordButton.addEventListener('touchstart', (e) => {
+                        e.preventDefault(); // Prevent default touch behavior
+                        startRecording();
+                    });

                     recordButton.addEventListener('mouseup', stopRecording);
-                    recordButton.addEventListener('touchend', stopRecording);
+                    recordButton.addEventListener('touchend', (e) => {
+                        e.preventDefault();
+                        stopRecording();
+                    });

                     recordButton.addEventListener('mouseleave', stopRecording);

                     status.textContent = 'Ready to record';
@@ -161,12 +323,13 @@
         }

         function startRecording() {
-            if (!isRecording) {
+            if (!isRecording && sessionActive) {
                 audioChunks = [];
                 mediaRecorder.start(100); // Collect data in 100ms chunks
                 recordButton.classList.add('recording');
                 recordButton.textContent = 'Release to Stop';
                 status.textContent = 'Recording...';
+                audioWave.classList.remove('hidden');
                 isRecording = true;

                 socket.emit('start_speaking');
@@ -186,15 +349,82 @@
                 mediaRecorder.stop();
                 recordButton.classList.remove('recording');
                 recordButton.textContent = 'Hold to Speak';
-                status.textContent = 'Processing...';
+                status.textContent = 'Processing speech...';
+                audioWave.classList.add('hidden');
+                isRecording = false;
             }
         }

+        function processRecording() {
+            if (audioChunks.length === 0) {
+                status.textContent = 'No audio recorded';
+                return;
+            }
+
+            const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
+
+            // Convert to ArrayBuffer for processing
+            const fileReader = new FileReader();
+            fileReader.onloadend = () => {
+                try {
+                    const arrayBuffer = fileReader.result;
+                    // Convert to Float32Array - this works better with WebAudio API
+                    const audioData = convertToFloat32(arrayBuffer);
+
+                    // Convert to base64 for sending
+                    const base64String = arrayBufferToBase64(audioData.buffer);
+                    socket.emit('audio_chunk', { audio: base64String });
+
+                    // Signal end of speech
+                    socket.emit('stop_speaking');
+                } catch (e) {
+                    console.error('Error processing audio:', e);
+                    status.textContent = 'Error processing audio';
+                }
+            };
+
+            fileReader.onerror = () => {
+                status.textContent = 'Error reading audio data';
+            };
+
+            fileReader.readAsArrayBuffer(audioBlob);
+        }
+
+        function convertToFloat32(arrayBuffer) {
+            // Get raw audio data as Int16 (common format for audio)
+            const int16Array = new Int16Array(arrayBuffer);
+
+            // Convert to Float32 (normalize between -1 and 1)
+            const float32Array = new Float32Array(int16Array.length);
+            for (let i = 0; i < int16Array.length; i++) {
+                float32Array[i] = int16Array[i] / 32768.0;
+            }
+
+            return float32Array;
+        }
+
         function addMessage(sender, text) {
+            const containerDiv = document.createElement('div');
+            containerDiv.className = sender === 'user' ? 'message-container user-message-container' : 'message-container bot-message-container';
+
+            const labelDiv = document.createElement('div');
+            labelDiv.className = 'message-label';
+            labelDiv.textContent = sender === 'user' ? 'You' : 'Assistant';
+            containerDiv.appendChild(labelDiv);
+
             const messageDiv = document.createElement('div');
-            messageDiv.className = sender === 'user' ? 'user-message' : 'bot-message';
+            messageDiv.className = sender === 'user' ? 'message user-message' : 'message bot-message';
             messageDiv.textContent = text;
-            conversation.appendChild(messageDiv);
+            containerDiv.appendChild(messageDiv);
+
+            if (sender === 'user') {
+                const infoDiv = document.createElement('div');
+                infoDiv.className = 'transcription-info';
+                infoDiv.textContent = 'Transcribed with Whisper';
+                containerDiv.appendChild(infoDiv);
+            }
+
+            conversation.appendChild(containerDiv);
             conversation.scrollTop = conversation.scrollHeight;
         }

@@ -207,6 +437,20 @@
             }
             return window.btoa(binary);
         }
+
+        // Handle page visibility change to avoid issues with background tabs
+        document.addEventListener('visibilitychange', () => {
+            if (document.hidden && isRecording) {
+                stopRecording();
+            }
+        });
+
+        // Clean disconnection when page is closed
+        window.addEventListener('beforeunload', () => {
+            if (socket && socket.connected) {
+                socket.disconnect();
+            }
+        });
     </script>
 </body>
 </html>
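The page above and the Flask-SocketIO server added below share a simple wire format for the 'audio_chunk' event: raw float32 samples, base64-encoded into a string (convertToFloat32/arrayBufferToBase64 on the client, handle_audio_chunk on the server). The round-trip sketch below is an illustration only, not part of the commit, and assumes numpy is available.

# Illustration (not in the commit): encode samples the way the page does,
# then decode them the way the server does.
import base64
import numpy as np

samples = np.sin(np.linspace(0, 2 * np.pi, 441)).astype(np.float32)   # stand-in audio
encoded = base64.b64encode(samples.tobytes()).decode('ascii')         # client -> server payload
decoded = np.frombuffer(base64.b64decode(encoded), dtype=np.float32)  # server-side view
assert np.array_equal(samples, decoded)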
Backend/req.txt (new file, 1 line)
@@ -0,0 +1 @@
+pip install faster-whisper
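Note that, as committed, req.txt holds a pip command rather than a package list; judging only from the imports in Backend/server.py below, a working requirements file would likely also need flask, flask-socketio, torch, torchaudio, transformers, and numpy alongside faster-whisper.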
Backend/server.py (new file, 256 lines)
@@ -0,0 +1,256 @@
+import os
+import io
+import base64
+import time
+import torch
+import torchaudio
+import numpy as np
+from flask import Flask, render_template, request
+from flask_socketio import SocketIO, emit
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from faster_whisper import WhisperModel
+from generator import load_csm_1b, Segment
+from collections import deque
+
+app = Flask(__name__)
+app.config['SECRET_KEY'] = 'your-secret-key'
+socketio = SocketIO(app, cors_allowed_origins="*")
+
+# Select the best available device
+if torch.cuda.is_available():
+    device = "cuda"
+    whisper_compute_type = "float16"
+elif torch.backends.mps.is_available():
+    device = "mps"
+    whisper_compute_type = "float32"
+else:
+    device = "cpu"
+    whisper_compute_type = "int8"
+
+print(f"Using device: {device}")
+
+# Initialize Faster-Whisper for transcription
+print("Loading Whisper model...")
+whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type)
+
+# Initialize CSM model for audio generation
+print("Loading CSM model...")
+csm_generator = load_csm_1b(device=device)
+
+# Initialize Llama 3.2 model for response generation
+print("Loading Llama 3.2 model...")
+llm_model_id = "meta-llama/Llama-3.2-1B"  # Choose appropriate size based on resources
+llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
+llm_model = AutoModelForCausalLM.from_pretrained(
+    llm_model_id,
+    torch_dtype=torch.bfloat16,
+    device_map=device
+)
+
+# Store conversation context
+conversation_context = {}  # session_id -> context
+
+@app.route('/')
+def index():
+    return render_template('index.html')
+
+@socketio.on('connect')
+def handle_connect():
+    print(f"Client connected: {request.sid}")
+    conversation_context[request.sid] = {
+        'segments': [],
+        'speakers': [0, 1],  # 0 = user, 1 = bot
+        'audio_buffer': deque(maxlen=10),  # Store recent audio chunks
+        'is_speaking': False,
+        'silence_start': None
+    }
+    emit('ready', {'message': 'Connection established'})
+
+@socketio.on('disconnect')
+def handle_disconnect():
+    print(f"Client disconnected: {request.sid}")
+    if request.sid in conversation_context:
+        del conversation_context[request.sid]
+
+@socketio.on('start_speaking')
+def handle_start_speaking():
+    if request.sid in conversation_context:
+        conversation_context[request.sid]['is_speaking'] = True
+        conversation_context[request.sid]['audio_buffer'].clear()
+        print(f"User {request.sid} started speaking")
+
+@socketio.on('audio_chunk')
+def handle_audio_chunk(data):
+    if request.sid not in conversation_context:
+        return
+
+    context = conversation_context[request.sid]
+
+    # Decode audio data
+    audio_data = base64.b64decode(data['audio'])
+    audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
+    audio_tensor = torch.tensor(audio_numpy)
+
+    # Add to buffer
+    context['audio_buffer'].append(audio_tensor)
+
+    # Check for silence to detect end of speech
+    if context['is_speaking'] and is_silence(audio_tensor):
+        if context['silence_start'] is None:
+            context['silence_start'] = time.time()
+        elif time.time() - context['silence_start'] > 1.0:  # 1 second of silence
+            # Process the complete utterance
+            process_user_utterance(request.sid)
+    else:
+        context['silence_start'] = None
+
+@socketio.on('stop_speaking')
+def handle_stop_speaking():
+    if request.sid in conversation_context:
+        conversation_context[request.sid]['is_speaking'] = False
+        process_user_utterance(request.sid)
+        print(f"User {request.sid} stopped speaking")
+
+def is_silence(audio_tensor, threshold=0.02):
+    """Check if an audio chunk is silence based on amplitude threshold"""
+    return torch.mean(torch.abs(audio_tensor)) < threshold
+
+def process_user_utterance(session_id):
+    """Process completed user utterance, generate response and send audio back"""
+    context = conversation_context[session_id]
+
+    if not context['audio_buffer']:
+        return
+
+    # Combine audio chunks
+    full_audio = torch.cat(list(context['audio_buffer']), dim=0)
+    context['audio_buffer'].clear()
+    context['is_speaking'] = False
+    context['silence_start'] = None
+
+    # Save audio to temporary WAV file for Whisper transcription
+    temp_audio_path = f"temp_audio_{session_id}.wav"
+    torchaudio.save(
+        temp_audio_path,
+        full_audio.unsqueeze(0),
+        44100  # Assuming 44.1kHz from client
+    )
+
+    # Transcribe speech using Faster-Whisper
+    try:
+        segments, info = whisper_model.transcribe(temp_audio_path, beam_size=5)
+
+        # Collect all text from segments
+        user_text = ""
+        for segment in segments:
+            segment_text = segment.text.strip()
+            print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}")
+            user_text += segment_text + " "
+
+        user_text = user_text.strip()
+
+        # Cleanup temp file
+        if os.path.exists(temp_audio_path):
+            os.remove(temp_audio_path)
+
+        if not user_text:
+            print("No speech detected.")
+            return
+
+        print(f"Transcribed: {user_text}")
+
+        # Add to conversation segments
+        user_segment = Segment(
+            text=user_text,
+            speaker=0,  # User is speaker 0
+            audio=full_audio
+        )
+        context['segments'].append(user_segment)
+
+        # Generate bot response
+        bot_response = generate_llm_response(user_text, context['segments'])
+        print(f"Bot response: {bot_response}")
+
+        # Convert to audio using CSM
+        bot_audio = generate_audio_response(bot_response, context['segments'])
+
+        # Convert audio to base64 for sending over websocket
+        audio_bytes = io.BytesIO()
+        torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav")
+        audio_bytes.seek(0)
+        audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')
+
+        # Add bot response to conversation history
+        bot_segment = Segment(
+            text=bot_response,
+            speaker=1,  # Bot is speaker 1
+            audio=bot_audio
+        )
+        context['segments'].append(bot_segment)
+
+        # Send transcribed text to client
+        emit('transcription', {'text': user_text}, room=session_id)
+
+        # Send audio response to client
+        emit('audio_response', {
+            'audio': audio_b64,
+            'text': bot_response
+        }, room=session_id)
+
+    except Exception as e:
+        print(f"Error processing speech: {e}")
+        emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id)
+        # Cleanup temp file in case of error
+        if os.path.exists(temp_audio_path):
+            os.remove(temp_audio_path)
+
+def generate_llm_response(user_text, conversation_segments):
+    """Generate text response using Llama 3.2"""
+    # Format conversation history for the LLM
+    conversation_history = ""
+    for segment in conversation_segments[-5:]:  # Use last 5 utterances for context
+        speaker_name = "User" if segment.speaker == 0 else "Assistant"
+        conversation_history += f"{speaker_name}: {segment.text}\n"
+
+    # Add the current user query
+    conversation_history += f"User: {user_text}\nAssistant:"
+
+    # Generate response
+    inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device)
+    output = llm_model.generate(
+        inputs.input_ids,
+        max_new_tokens=150,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True
+    )
+
+    response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+    return response.strip()
+
+def generate_audio_response(text, conversation_segments):
+    """Generate audio response using CSM"""
+    # Use the last few conversation segments as context
+    context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
+
+    # Generate audio for bot response
+    audio = csm_generator.generate(
+        text=text,
+        speaker=1,  # Bot is speaker 1
+        context=context_segments,
+        max_audio_length_ms=10000,  # 10 seconds max
+        temperature=0.9,
+        topk=50
+    )
+
+    return audio
+
+if __name__ == '__main__':
+    # Ensure the existing index.html file is in the correct location
+    if not os.path.exists('templates'):
+        os.makedirs('templates')
+
+    if os.path.exists('index.html') and not os.path.exists('templates/index.html'):
+        os.rename('index.html', 'templates/index.html')
+
+    socketio.run(app, host='0.0.0.0', port=5000, debug=False)
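For a quick smoke test of the SocketIO protocol without opening the page, something like the sketch below can drive the same events ('start_speaking', 'audio_chunk', 'stop_speaking') and print what comes back. It is not part of the commit and makes a few assumptions: the server above is running on localhost:5000, the python-socketio client package and numpy are installed, and the payload is raw float32 samples as handle_audio_chunk expects. A pure tone is not speech, so Whisper may return an empty transcription and no reply will be emitted; the point is only to exercise the wire format.

# Hypothetical smoke-test client (not part of this commit).
import base64
import numpy as np
import socketio

sio = socketio.Client()

@sio.on('ready')
def on_ready(data):
    print('server ready:', data)

@sio.on('transcription')
def on_transcription(data):
    print('transcribed:', data['text'])

@sio.on('audio_response')
def on_audio_response(data):
    print('bot said:', data['text'])
    with open('reply.wav', 'wb') as f:  # the server sends base64-encoded WAV bytes
        f.write(base64.b64decode(data['audio']))

@sio.on('error')
def on_error(data):
    print('server error:', data['message'])

sio.connect('http://localhost:5000')
sio.emit('start_speaking')

# One second of a 440 Hz tone as float32 samples, the format that
# handle_audio_chunk decodes with np.frombuffer(..., dtype=np.float32).
t = np.linspace(0, 1, 44100, dtype=np.float32)
chunk = (0.2 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
sio.emit('audio_chunk', {'audio': base64.b64encode(chunk.tobytes()).decode('ascii')})
sio.emit('stop_speaking')

sio.sleep(60)  # give the pipeline time to respond
if sio.connected:
    sio.disconnect()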