Demo Update 21
This commit is contained in:
@@ -1,86 +1,225 @@
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Audio Conversation Bot</title>
|
||||
<title>Voice Assistant - CSM & Whisper</title>
|
||||
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
background-color: #f5f7fa;
|
||||
color: #333;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #2c3e50;
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
#conversation {
|
||||
height: 400px;
|
||||
border: 1px solid #ccc;
|
||||
padding: 15px;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 10px;
|
||||
padding: 20px;
|
||||
margin-bottom: 20px;
|
||||
overflow-y: auto;
|
||||
background-color: white;
|
||||
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.message-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
.user-message-container {
|
||||
align-items: flex-end;
|
||||
}
|
||||
|
||||
.bot-message-container {
|
||||
align-items: flex-start;
|
||||
}
|
||||
|
||||
.message {
|
||||
max-width: 80%;
|
||||
padding: 12px;
|
||||
border-radius: 18px;
|
||||
position: relative;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.user-message {
|
||||
background-color: #e1f5fe;
|
||||
padding: 10px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 10px;
|
||||
align-self: flex-end;
|
||||
background-color: #dcf8c6;
|
||||
color: #000;
|
||||
border-bottom-right-radius: 4px;
|
||||
}
|
||||
|
||||
.bot-message {
|
||||
background-color: #f1f1f1;
|
||||
padding: 10px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 10px;
|
||||
background-color: #f1f0f0;
|
||||
color: #000;
|
||||
border-bottom-left-radius: 4px;
|
||||
}
|
||||
|
||||
.message-label {
|
||||
font-size: 0.8em;
|
||||
margin-bottom: 4px;
|
||||
color: #657786;
|
||||
}
|
||||
|
||||
#controls {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
justify-content: center;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 10px 20px;
|
||||
padding: 12px 24px;
|
||||
font-size: 16px;
|
||||
cursor: pointer;
|
||||
border-radius: 50px;
|
||||
border: none;
|
||||
outline: none;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
#recordButton {
|
||||
background-color: #4CAF50;
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
width: 200px;
|
||||
box-shadow: 0 4px 8px rgba(76, 175, 80, 0.3);
|
||||
}
|
||||
|
||||
#recordButton:hover {
|
||||
background-color: #45a049;
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
#recordButton.recording {
|
||||
background-color: #f44336;
|
||||
animation: pulse 1.5s infinite;
|
||||
box-shadow: 0 4px 8px rgba(244, 67, 54, 0.3);
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0% {
|
||||
transform: scale(1);
|
||||
}
|
||||
50% {
|
||||
transform: scale(1.05);
|
||||
}
|
||||
100% {
|
||||
transform: scale(1);
|
||||
}
|
||||
}
|
||||
|
||||
#status {
|
||||
margin-top: 10px;
|
||||
text-align: center;
|
||||
margin-top: 15px;
|
||||
font-style: italic;
|
||||
color: #657786;
|
||||
}
|
||||
|
||||
.audio-wave {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 40px;
|
||||
gap: 3px;
|
||||
}
|
||||
|
||||
.audio-wave span {
|
||||
display: block;
|
||||
width: 3px;
|
||||
height: 100%;
|
||||
background-color: #4CAF50;
|
||||
animation: wave 1.5s infinite ease-in-out;
|
||||
border-radius: 6px;
|
||||
}
|
||||
|
||||
.audio-wave span:nth-child(2) {
|
||||
animation-delay: 0.2s;
|
||||
}
|
||||
.audio-wave span:nth-child(3) {
|
||||
animation-delay: 0.4s;
|
||||
}
|
||||
.audio-wave span:nth-child(4) {
|
||||
animation-delay: 0.6s;
|
||||
}
|
||||
.audio-wave span:nth-child(5) {
|
||||
animation-delay: 0.8s;
|
||||
}
|
||||
|
||||
@keyframes wave {
|
||||
0%, 100% {
|
||||
height: 8px;
|
||||
}
|
||||
50% {
|
||||
height: 30px;
|
||||
}
|
||||
}
|
||||
|
||||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.transcription-info {
|
||||
font-size: 0.8em;
|
||||
color: #888;
|
||||
margin-top: 4px;
|
||||
text-align: right;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Audio Conversation Bot</h1>
|
||||
<h1>Voice Assistant with CSM & Whisper</h1>
|
||||
<div id="conversation"></div>
|
||||
|
||||
<div id="controls">
|
||||
<button id="recordButton">Hold to Speak</button>
|
||||
</div>
|
||||
<div id="status">Not connected</div>
|
||||
|
||||
<div id="audioWave" class="audio-wave hidden">
|
||||
<span></span>
|
||||
<span></span>
|
||||
<span></span>
|
||||
<span></span>
|
||||
<span></span>
|
||||
</div>
|
||||
|
||||
<div id="status">Connecting to server...</div>
|
||||
|
||||
<script>
|
||||
const socket = io();
|
||||
const recordButton = document.getElementById('recordButton');
|
||||
const conversation = document.getElementById('conversation');
|
||||
const status = document.getElementById('status');
|
||||
const audioWave = document.getElementById('audioWave');
|
||||
|
||||
let mediaRecorder;
|
||||
let audioChunks = [];
|
||||
let isRecording = false;
|
||||
let audioSendInterval;
|
||||
let sessionActive = false;
|
||||
|
||||
// Initialize audio context and analyzer
|
||||
// Initialize audio context
|
||||
const audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
|
||||
// Connect to server
|
||||
socket.on('connect', () => {
|
||||
status.textContent = 'Connected to server';
|
||||
sessionActive = true;
|
||||
});
|
||||
|
||||
socket.on('disconnect', () => {
|
||||
status.textContent = 'Disconnected from server';
|
||||
sessionActive = false;
|
||||
});
|
||||
|
||||
socket.on('ready', (data) => {
|
||||
@@ -90,28 +229,59 @@
|
||||
|
||||
socket.on('transcription', (data) => {
|
||||
addMessage('user', data.text);
|
||||
status.textContent = 'Assistant is thinking...';
|
||||
});
|
||||
|
||||
socket.on('audio_response', (data) => {
|
||||
// Play audio
|
||||
status.textContent = 'Playing response...';
|
||||
const audio = new Audio('data:audio/wav;base64,' + data.audio);
|
||||
audio.play();
|
||||
|
||||
audio.onended = () => {
|
||||
status.textContent = 'Ready to record';
|
||||
};
|
||||
|
||||
audio.onerror = () => {
|
||||
status.textContent = 'Error playing audio';
|
||||
console.error('Error playing audio response');
|
||||
};
|
||||
|
||||
audio.play().catch(err => {
|
||||
status.textContent = 'Error playing audio: ' + err.message;
|
||||
console.error('Error playing audio:', err);
|
||||
});
|
||||
|
||||
// Display text
|
||||
addMessage('bot', data.text);
|
||||
});
|
||||
|
||||
socket.on('error', (data) => {
|
||||
status.textContent = data.message;
|
||||
console.error(data.message);
|
||||
status.textContent = 'Error: ' + data.message;
|
||||
console.error('Server error:', data.message);
|
||||
});
|
||||
|
||||
function setupAudioRecording() {
|
||||
// Check if browser supports required APIs
|
||||
if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
|
||||
status.textContent = 'Your browser does not support audio recording';
|
||||
return;
|
||||
}
|
||||
|
||||
// Get user media
|
||||
navigator.mediaDevices.getUserMedia({ audio: true })
|
||||
.then(stream => {
|
||||
// Setup recording
|
||||
// Setup recording with better audio quality
|
||||
const options = {
|
||||
mimeType: 'audio/webm',
|
||||
audioBitsPerSecond: 128000
|
||||
};
|
||||
|
||||
try {
|
||||
mediaRecorder = new MediaRecorder(stream, options);
|
||||
} catch (e) {
|
||||
// Fallback if the specified options aren't supported
|
||||
mediaRecorder = new MediaRecorder(stream);
|
||||
}
|
||||
|
||||
mediaRecorder.ondataavailable = event => {
|
||||
if (event.data.size > 0) {
|
||||
@@ -120,36 +290,28 @@
|
||||
};
|
||||
|
||||
mediaRecorder.onstop = () => {
|
||||
const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
|
||||
audioChunks = [];
|
||||
|
||||
// Convert to Float32Array for sending
|
||||
const fileReader = new FileReader();
|
||||
fileReader.onloadend = () => {
|
||||
const arrayBuffer = fileReader.result;
|
||||
const floatArray = new Float32Array(arrayBuffer);
|
||||
|
||||
// Convert to base64
|
||||
const base64String = arrayBufferToBase64(floatArray.buffer);
|
||||
socket.emit('audio_chunk', { audio: base64String });
|
||||
};
|
||||
fileReader.readAsArrayBuffer(audioBlob);
|
||||
|
||||
socket.emit('stop_speaking');
|
||||
isRecording = false;
|
||||
processRecording();
|
||||
};
|
||||
|
||||
// Setup audio analyzer for chunking and VAD
|
||||
// Create audio analyzer for visualization
|
||||
const source = audioContext.createMediaStreamSource(stream);
|
||||
const analyzer = audioContext.createAnalyser();
|
||||
analyzer.fftSize = 2048;
|
||||
source.connect(analyzer);
|
||||
|
||||
// Setup button handlers
|
||||
// Setup button handlers with better touch handling
|
||||
recordButton.addEventListener('mousedown', startRecording);
|
||||
recordButton.addEventListener('touchstart', startRecording);
|
||||
recordButton.addEventListener('touchstart', (e) => {
|
||||
e.preventDefault(); // Prevent default touch behavior
|
||||
startRecording();
|
||||
});
|
||||
|
||||
recordButton.addEventListener('mouseup', stopRecording);
|
||||
recordButton.addEventListener('touchend', stopRecording);
|
||||
recordButton.addEventListener('touchend', (e) => {
|
||||
e.preventDefault();
|
||||
stopRecording();
|
||||
});
|
||||
|
||||
recordButton.addEventListener('mouseleave', stopRecording);
|
||||
|
||||
status.textContent = 'Ready to record';
|
||||
@@ -161,12 +323,13 @@
|
||||
}
|
||||
|
||||
function startRecording() {
|
||||
if (!isRecording) {
|
||||
if (!isRecording && sessionActive) {
|
||||
audioChunks = [];
|
||||
mediaRecorder.start(100); // Collect data in 100ms chunks
|
||||
recordButton.classList.add('recording');
|
||||
recordButton.textContent = 'Release to Stop';
|
||||
status.textContent = 'Recording...';
|
||||
audioWave.classList.remove('hidden');
|
||||
isRecording = true;
|
||||
|
||||
socket.emit('start_speaking');
|
||||
@@ -186,15 +349,82 @@
|
||||
mediaRecorder.stop();
|
||||
recordButton.classList.remove('recording');
|
||||
recordButton.textContent = 'Hold to Speak';
|
||||
status.textContent = 'Processing...';
|
||||
status.textContent = 'Processing speech...';
|
||||
audioWave.classList.add('hidden');
|
||||
isRecording = false;
|
||||
}
|
||||
}
|
||||
|
||||
function processRecording() {
|
||||
if (audioChunks.length === 0) {
|
||||
status.textContent = 'No audio recorded';
|
||||
return;
|
||||
}
|
||||
|
||||
const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
|
||||
|
||||
// Convert to ArrayBuffer for processing
|
||||
const fileReader = new FileReader();
|
||||
fileReader.onloadend = () => {
|
||||
try {
|
||||
const arrayBuffer = fileReader.result;
|
||||
// Convert to Float32Array - this works better with WebAudio API
|
||||
const audioData = convertToFloat32(arrayBuffer);
|
||||
|
||||
// Convert to base64 for sending
|
||||
const base64String = arrayBufferToBase64(audioData.buffer);
|
||||
socket.emit('audio_chunk', { audio: base64String });
|
||||
|
||||
// Signal end of speech
|
||||
socket.emit('stop_speaking');
|
||||
} catch (e) {
|
||||
console.error('Error processing audio:', e);
|
||||
status.textContent = 'Error processing audio';
|
||||
}
|
||||
};
|
||||
|
||||
fileReader.onerror = () => {
|
||||
status.textContent = 'Error reading audio data';
|
||||
};
|
||||
|
||||
fileReader.readAsArrayBuffer(audioBlob);
|
||||
}
|
||||
|
||||
function convertToFloat32(arrayBuffer) {
|
||||
// Get raw audio data as Int16 (common format for audio)
|
||||
const int16Array = new Int16Array(arrayBuffer);
|
||||
|
||||
// Convert to Float32 (normalize between -1 and 1)
|
||||
const float32Array = new Float32Array(int16Array.length);
|
||||
for (let i = 0; i < int16Array.length; i++) {
|
||||
float32Array[i] = int16Array[i] / 32768.0;
|
||||
}
|
||||
|
||||
return float32Array;
|
||||
}
|
||||
|
||||
function addMessage(sender, text) {
|
||||
const containerDiv = document.createElement('div');
|
||||
containerDiv.className = sender === 'user' ? 'message-container user-message-container' : 'message-container bot-message-container';
|
||||
|
||||
const labelDiv = document.createElement('div');
|
||||
labelDiv.className = 'message-label';
|
||||
labelDiv.textContent = sender === 'user' ? 'You' : 'Assistant';
|
||||
containerDiv.appendChild(labelDiv);
|
||||
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = sender === 'user' ? 'user-message' : 'bot-message';
|
||||
messageDiv.className = sender === 'user' ? 'message user-message' : 'message bot-message';
|
||||
messageDiv.textContent = text;
|
||||
conversation.appendChild(messageDiv);
|
||||
containerDiv.appendChild(messageDiv);
|
||||
|
||||
if (sender === 'user') {
|
||||
const infoDiv = document.createElement('div');
|
||||
infoDiv.className = 'transcription-info';
|
||||
infoDiv.textContent = 'Transcribed with Whisper';
|
||||
containerDiv.appendChild(infoDiv);
|
||||
}
|
||||
|
||||
conversation.appendChild(containerDiv);
|
||||
conversation.scrollTop = conversation.scrollHeight;
|
||||
}
|
||||
|
||||
@@ -207,6 +437,20 @@
|
||||
}
|
||||
return window.btoa(binary);
|
||||
}
|
||||
|
||||
// Handle page visibility change to avoid issues with background tabs
|
||||
document.addEventListener('visibilitychange', () => {
|
||||
if (document.hidden && isRecording) {
|
||||
stopRecording();
|
||||
}
|
||||
});
|
||||
|
||||
// Clean disconnection when page is closed
|
||||
window.addEventListener('beforeunload', () => {
|
||||
if (socket && socket.connected) {
|
||||
socket.disconnect();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
1
Backend/req.txt
Normal file
1
Backend/req.txt
Normal file
@@ -0,0 +1 @@
|
||||
pip install faster-whisper
|
||||
256
Backend/server.py
Normal file
256
Backend/server.py
Normal file
@@ -0,0 +1,256 @@
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import time
|
||||
import torch
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
from flask import Flask, render_template, request
|
||||
from flask_socketio import SocketIO, emit
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from faster_whisper import WhisperModel
|
||||
from generator import load_csm_1b, Segment
|
||||
from collections import deque
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config['SECRET_KEY'] = 'your-secret-key'
|
||||
socketio = SocketIO(app, cors_allowed_origins="*")
|
||||
|
||||
# Select the best available device
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
whisper_compute_type = "float16"
|
||||
elif torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
whisper_compute_type = "float32"
|
||||
else:
|
||||
device = "cpu"
|
||||
whisper_compute_type = "int8"
|
||||
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Initialize Faster-Whisper for transcription
|
||||
print("Loading Whisper model...")
|
||||
whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type)
|
||||
|
||||
# Initialize CSM model for audio generation
|
||||
print("Loading CSM model...")
|
||||
csm_generator = load_csm_1b(device=device)
|
||||
|
||||
# Initialize Llama 3.2 model for response generation
|
||||
print("Loading Llama 3.2 model...")
|
||||
llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources
|
||||
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
|
||||
llm_model = AutoModelForCausalLM.from_pretrained(
|
||||
llm_model_id,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device
|
||||
)
|
||||
|
||||
# Store conversation context
|
||||
conversation_context = {} # session_id -> context
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
return render_template('index.html')
|
||||
|
||||
@socketio.on('connect')
|
||||
def handle_connect():
|
||||
print(f"Client connected: {request.sid}")
|
||||
conversation_context[request.sid] = {
|
||||
'segments': [],
|
||||
'speakers': [0, 1], # 0 = user, 1 = bot
|
||||
'audio_buffer': deque(maxlen=10), # Store recent audio chunks
|
||||
'is_speaking': False,
|
||||
'silence_start': None
|
||||
}
|
||||
emit('ready', {'message': 'Connection established'})
|
||||
|
||||
@socketio.on('disconnect')
|
||||
def handle_disconnect():
|
||||
print(f"Client disconnected: {request.sid}")
|
||||
if request.sid in conversation_context:
|
||||
del conversation_context[request.sid]
|
||||
|
||||
@socketio.on('start_speaking')
|
||||
def handle_start_speaking():
|
||||
if request.sid in conversation_context:
|
||||
conversation_context[request.sid]['is_speaking'] = True
|
||||
conversation_context[request.sid]['audio_buffer'].clear()
|
||||
print(f"User {request.sid} started speaking")
|
||||
|
||||
@socketio.on('audio_chunk')
|
||||
def handle_audio_chunk(data):
|
||||
if request.sid not in conversation_context:
|
||||
return
|
||||
|
||||
context = conversation_context[request.sid]
|
||||
|
||||
# Decode audio data
|
||||
audio_data = base64.b64decode(data['audio'])
|
||||
audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
|
||||
audio_tensor = torch.tensor(audio_numpy)
|
||||
|
||||
# Add to buffer
|
||||
context['audio_buffer'].append(audio_tensor)
|
||||
|
||||
# Check for silence to detect end of speech
|
||||
if context['is_speaking'] and is_silence(audio_tensor):
|
||||
if context['silence_start'] is None:
|
||||
context['silence_start'] = time.time()
|
||||
elif time.time() - context['silence_start'] > 1.0: # 1 second of silence
|
||||
# Process the complete utterance
|
||||
process_user_utterance(request.sid)
|
||||
else:
|
||||
context['silence_start'] = None
|
||||
|
||||
@socketio.on('stop_speaking')
|
||||
def handle_stop_speaking():
|
||||
if request.sid in conversation_context:
|
||||
conversation_context[request.sid]['is_speaking'] = False
|
||||
process_user_utterance(request.sid)
|
||||
print(f"User {request.sid} stopped speaking")
|
||||
|
||||
def is_silence(audio_tensor, threshold=0.02):
|
||||
"""Check if an audio chunk is silence based on amplitude threshold"""
|
||||
return torch.mean(torch.abs(audio_tensor)) < threshold
|
||||
|
||||
def process_user_utterance(session_id):
|
||||
"""Process completed user utterance, generate response and send audio back"""
|
||||
context = conversation_context[session_id]
|
||||
|
||||
if not context['audio_buffer']:
|
||||
return
|
||||
|
||||
# Combine audio chunks
|
||||
full_audio = torch.cat(list(context['audio_buffer']), dim=0)
|
||||
context['audio_buffer'].clear()
|
||||
context['is_speaking'] = False
|
||||
context['silence_start'] = None
|
||||
|
||||
# Save audio to temporary WAV file for Whisper transcription
|
||||
temp_audio_path = f"temp_audio_{session_id}.wav"
|
||||
torchaudio.save(
|
||||
temp_audio_path,
|
||||
full_audio.unsqueeze(0),
|
||||
44100 # Assuming 44.1kHz from client
|
||||
)
|
||||
|
||||
# Transcribe speech using Faster-Whisper
|
||||
try:
|
||||
segments, info = whisper_model.transcribe(temp_audio_path, beam_size=5)
|
||||
|
||||
# Collect all text from segments
|
||||
user_text = ""
|
||||
for segment in segments:
|
||||
segment_text = segment.text.strip()
|
||||
print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}")
|
||||
user_text += segment_text + " "
|
||||
|
||||
user_text = user_text.strip()
|
||||
|
||||
# Cleanup temp file
|
||||
if os.path.exists(temp_audio_path):
|
||||
os.remove(temp_audio_path)
|
||||
|
||||
if not user_text:
|
||||
print("No speech detected.")
|
||||
return
|
||||
|
||||
print(f"Transcribed: {user_text}")
|
||||
|
||||
# Add to conversation segments
|
||||
user_segment = Segment(
|
||||
text=user_text,
|
||||
speaker=0, # User is speaker 0
|
||||
audio=full_audio
|
||||
)
|
||||
context['segments'].append(user_segment)
|
||||
|
||||
# Generate bot response
|
||||
bot_response = generate_llm_response(user_text, context['segments'])
|
||||
print(f"Bot response: {bot_response}")
|
||||
|
||||
# Convert to audio using CSM
|
||||
bot_audio = generate_audio_response(bot_response, context['segments'])
|
||||
|
||||
# Convert audio to base64 for sending over websocket
|
||||
audio_bytes = io.BytesIO()
|
||||
torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav")
|
||||
audio_bytes.seek(0)
|
||||
audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')
|
||||
|
||||
# Add bot response to conversation history
|
||||
bot_segment = Segment(
|
||||
text=bot_response,
|
||||
speaker=1, # Bot is speaker 1
|
||||
audio=bot_audio
|
||||
)
|
||||
context['segments'].append(bot_segment)
|
||||
|
||||
# Send transcribed text to client
|
||||
emit('transcription', {'text': user_text}, room=session_id)
|
||||
|
||||
# Send audio response to client
|
||||
emit('audio_response', {
|
||||
'audio': audio_b64,
|
||||
'text': bot_response
|
||||
}, room=session_id)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing speech: {e}")
|
||||
emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id)
|
||||
# Cleanup temp file in case of error
|
||||
if os.path.exists(temp_audio_path):
|
||||
os.remove(temp_audio_path)
|
||||
|
||||
def generate_llm_response(user_text, conversation_segments):
|
||||
"""Generate text response using Llama 3.2"""
|
||||
# Format conversation history for the LLM
|
||||
conversation_history = ""
|
||||
for segment in conversation_segments[-5:]: # Use last 5 utterances for context
|
||||
speaker_name = "User" if segment.speaker == 0 else "Assistant"
|
||||
conversation_history += f"{speaker_name}: {segment.text}\n"
|
||||
|
||||
# Add the current user query
|
||||
conversation_history += f"User: {user_text}\nAssistant:"
|
||||
|
||||
# Generate response
|
||||
inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device)
|
||||
output = llm_model.generate(
|
||||
inputs.input_ids,
|
||||
max_new_tokens=150,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
do_sample=True
|
||||
)
|
||||
|
||||
response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
return response.strip()
|
||||
|
||||
def generate_audio_response(text, conversation_segments):
|
||||
"""Generate audio response using CSM"""
|
||||
# Use the last few conversation segments as context
|
||||
context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
|
||||
|
||||
# Generate audio for bot response
|
||||
audio = csm_generator.generate(
|
||||
text=text,
|
||||
speaker=1, # Bot is speaker 1
|
||||
context=context_segments,
|
||||
max_audio_length_ms=10000, # 10 seconds max
|
||||
temperature=0.9,
|
||||
topk=50
|
||||
)
|
||||
|
||||
return audio
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Ensure the existing index.html file is in the correct location
|
||||
if not os.path.exists('templates'):
|
||||
os.makedirs('templates')
|
||||
|
||||
if os.path.exists('index.html') and not os.path.exists('templates/index.html'):
|
||||
os.rename('index.html', 'templates/index.html')
|
||||
|
||||
socketio.run(app, host='0.0.0.0', port=5000, debug=False)
|
||||
Reference in New Issue
Block a user