Demo Update 22
This commit is contained in:
@@ -3,7 +3,7 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
<title>Live Voice Assistant with CSM</title>
|
<title>Real-Time Voice Assistant</title>
|
||||||
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
|
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
|
||||||
<style>
|
<style>
|
||||||
body {
|
body {
|
||||||
@@ -89,33 +89,39 @@
|
|||||||
transition: all 0.3s ease;
|
transition: all 0.3s ease;
|
||||||
}
|
}
|
||||||
|
|
||||||
#talkButton {
|
#micButton {
|
||||||
background-color: #4CAF50;
|
background-color: #4CAF50;
|
||||||
color: white;
|
color: white;
|
||||||
width: 200px;
|
width: 200px;
|
||||||
box-shadow: 0 4px 8px rgba(76, 175, 80, 0.3);
|
box-shadow: 0 4px 8px rgba(76, 175, 80, 0.3);
|
||||||
}
|
}
|
||||||
|
|
||||||
#talkButton:hover {
|
#micButton:hover {
|
||||||
background-color: #45a049;
|
background-color: #45a049;
|
||||||
transform: translateY(-2px);
|
transform: translateY(-2px);
|
||||||
}
|
}
|
||||||
|
|
||||||
#talkButton.recording {
|
#micButton.listening {
|
||||||
background-color: #f44336;
|
background-color: #4CAF50;
|
||||||
|
box-shadow: 0 0 0 rgba(76, 175, 80, 0.4);
|
||||||
|
animation: pulse 1.5s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
#micButton.speaking {
|
||||||
|
background-color: #f44336;
|
||||||
|
box-shadow: 0 0 0 rgba(244, 67, 54, 0.4);
|
||||||
animation: pulse 1.5s infinite;
|
animation: pulse 1.5s infinite;
|
||||||
box-shadow: 0 4px 8px rgba(244, 67, 54, 0.3);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@keyframes pulse {
|
@keyframes pulse {
|
||||||
0% {
|
0% {
|
||||||
transform: scale(1);
|
box-shadow: 0 0 0 0 rgba(76, 175, 80, 0.4);
|
||||||
}
|
}
|
||||||
50% {
|
70% {
|
||||||
transform: scale(1.05);
|
box-shadow: 0 0 0 15px rgba(76, 175, 80, 0);
|
||||||
}
|
}
|
||||||
100% {
|
100% {
|
||||||
transform: scale(1);
|
box-shadow: 0 0 0 0 rgba(76, 175, 80, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,45 +132,24 @@
|
|||||||
color: #657786;
|
color: #657786;
|
||||||
}
|
}
|
||||||
|
|
||||||
.hidden {
|
|
||||||
display: none;
|
|
||||||
}
|
|
||||||
|
|
||||||
.transcription-info {
|
|
||||||
font-size: 0.8em;
|
|
||||||
color: #888;
|
|
||||||
margin-top: 4px;
|
|
||||||
text-align: right;
|
|
||||||
}
|
|
||||||
|
|
||||||
.text-only-indicator {
|
|
||||||
font-size: 0.8em;
|
|
||||||
color: #e74c3c;
|
|
||||||
margin-top: 4px;
|
|
||||||
font-style: italic;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-message {
|
|
||||||
text-align: center;
|
|
||||||
padding: 8px;
|
|
||||||
margin: 10px 0;
|
|
||||||
background-color: #f8f9fa;
|
|
||||||
border-radius: 5px;
|
|
||||||
color: #666;
|
|
||||||
font-size: 0.9em;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Audio visualizer styles */
|
|
||||||
.visualizer-container {
|
.visualizer-container {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
height: 120px;
|
height: 100px;
|
||||||
margin: 15px 0;
|
margin: 15px 0;
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
background-color: #000;
|
background-color: #1a1a1a;
|
||||||
position: relative;
|
position: relative;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.visualizer-container.user {
|
||||||
|
border: 2px solid #4CAF50;
|
||||||
|
}
|
||||||
|
|
||||||
|
.visualizer-container.ai {
|
||||||
|
border: 2px solid #2196F3;
|
||||||
|
}
|
||||||
|
|
||||||
#visualizer {
|
#visualizer {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
@@ -176,122 +161,59 @@
|
|||||||
top: 10px;
|
top: 10px;
|
||||||
left: 10px;
|
left: 10px;
|
||||||
color: white;
|
color: white;
|
||||||
font-size: 0.8em;
|
font-size: 0.9em;
|
||||||
background-color: rgba(0, 0, 0, 0.5);
|
background-color: rgba(0, 0, 0, 0.5);
|
||||||
padding: 4px 8px;
|
padding: 4px 8px;
|
||||||
border-radius: 4px;
|
border-radius: 4px;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Real-time transcription */
|
.speech-indicator {
|
||||||
.live-transcription {
|
|
||||||
position: absolute;
|
|
||||||
bottom: 10px;
|
|
||||||
left: 10px;
|
|
||||||
right: 10px;
|
|
||||||
color: white;
|
|
||||||
font-size: 0.9em;
|
|
||||||
background-color: rgba(0, 0, 0, 0.5);
|
|
||||||
padding: 8px;
|
|
||||||
border-radius: 4px;
|
|
||||||
text-align: center;
|
|
||||||
max-height: 60px;
|
|
||||||
overflow-y: auto;
|
|
||||||
font-style: italic;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Wave animation for active speaker */
|
|
||||||
.speaking-wave {
|
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
margin-left: 5px;
|
width: 10px;
|
||||||
|
height: 10px;
|
||||||
|
border-radius: 50%;
|
||||||
|
margin-right: 5px;
|
||||||
vertical-align: middle;
|
vertical-align: middle;
|
||||||
}
|
}
|
||||||
|
|
||||||
.speaking-wave span {
|
.user-speaking {
|
||||||
display: inline-block;
|
|
||||||
width: 3px;
|
|
||||||
height: 12px;
|
|
||||||
margin: 0 1px;
|
|
||||||
background-color: currentColor;
|
|
||||||
border-radius: 1px;
|
|
||||||
animation: speakingWave 1s infinite ease-in-out;
|
|
||||||
}
|
|
||||||
|
|
||||||
.speaking-wave span:nth-child(2) {
|
|
||||||
animation-delay: 0.1s;
|
|
||||||
}
|
|
||||||
|
|
||||||
.speaking-wave span:nth-child(3) {
|
|
||||||
animation-delay: 0.2s;
|
|
||||||
}
|
|
||||||
|
|
||||||
.speaking-wave span:nth-child(4) {
|
|
||||||
animation-delay: 0.3s;
|
|
||||||
}
|
|
||||||
|
|
||||||
@keyframes speakingWave {
|
|
||||||
0%, 100% {
|
|
||||||
height: 4px;
|
|
||||||
}
|
|
||||||
50% {
|
|
||||||
height: 12px;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Modern switch for visualizer toggle */
|
|
||||||
.switch-container {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
margin-bottom: 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.switch {
|
|
||||||
position: relative;
|
|
||||||
display: inline-block;
|
|
||||||
width: 50px;
|
|
||||||
height: 24px;
|
|
||||||
margin-left: 10px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.switch input {
|
|
||||||
opacity: 0;
|
|
||||||
width: 0;
|
|
||||||
height: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.slider {
|
|
||||||
position: absolute;
|
|
||||||
cursor: pointer;
|
|
||||||
top: 0;
|
|
||||||
left: 0;
|
|
||||||
right: 0;
|
|
||||||
bottom: 0;
|
|
||||||
background-color: #ccc;
|
|
||||||
transition: .4s;
|
|
||||||
border-radius: 24px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.slider:before {
|
|
||||||
position: absolute;
|
|
||||||
content: "";
|
|
||||||
height: 16px;
|
|
||||||
width: 16px;
|
|
||||||
left: 4px;
|
|
||||||
bottom: 4px;
|
|
||||||
background-color: white;
|
|
||||||
transition: .4s;
|
|
||||||
border-radius: 50%;
|
|
||||||
}
|
|
||||||
|
|
||||||
input:checked + .slider {
|
|
||||||
background-color: #4CAF50;
|
background-color: #4CAF50;
|
||||||
|
animation: blink 1s infinite;
|
||||||
}
|
}
|
||||||
|
|
||||||
input:checked + .slider:before {
|
.ai-speaking {
|
||||||
transform: translateX(26px);
|
background-color: #2196F3;
|
||||||
|
animation: blink 1s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes blink {
|
||||||
|
0%, 100% { opacity: 1; }
|
||||||
|
50% { opacity: 0.4; }
|
||||||
|
}
|
||||||
|
|
||||||
|
.connection-status {
|
||||||
|
padding: 6px 10px;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.8em;
|
||||||
|
margin-top: 10px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.connection-status.connected {
|
||||||
|
background-color: #d4edda;
|
||||||
|
color: #155724;
|
||||||
|
}
|
||||||
|
|
||||||
|
.connection-status.connecting {
|
||||||
|
background-color: #fff3cd;
|
||||||
|
color: #856404;
|
||||||
|
}
|
||||||
|
|
||||||
|
.connection-status.disconnected {
|
||||||
|
background-color: #f8d7da;
|
||||||
|
color: #721c24;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Toast notification for feedback */
|
|
||||||
.toast {
|
.toast {
|
||||||
position: fixed;
|
position: fixed;
|
||||||
bottom: 20px;
|
bottom: 20px;
|
||||||
@@ -328,38 +250,27 @@
|
|||||||
</style>
|
</style>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<h1>Live Voice Assistant with CSM</h1>
|
<h1>Real-Time Voice Assistant</h1>
|
||||||
<div id="conversation"></div>
|
<div id="conversation"></div>
|
||||||
|
|
||||||
<div class="switch-container">
|
<div class="visualizer-container user" id="visualizerContainer">
|
||||||
<span>Audio Visualizer</span>
|
|
||||||
<label class="switch">
|
|
||||||
<input type="checkbox" id="visualizerToggle" checked>
|
|
||||||
<span class="slider"></span>
|
|
||||||
</label>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="visualizer-container" id="visualizerContainer">
|
|
||||||
<canvas id="visualizer"></canvas>
|
<canvas id="visualizer"></canvas>
|
||||||
<div class="visualizer-label" id="visualizerLabel">Listening...</div>
|
<div class="visualizer-label" id="visualizerLabel">Listening...</div>
|
||||||
<div class="live-transcription" id="liveTranscription"></div>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div id="controls">
|
<div id="controls">
|
||||||
<button id="talkButton">Press to Talk</button>
|
<button id="micButton">Press to Talk</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div id="status">Connecting to server...</div>
|
<div id="status">Connecting to server...</div>
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
const socket = io();
|
const socket = io();
|
||||||
const talkButton = document.getElementById('talkButton');
|
const micButton = document.getElementById('micButton');
|
||||||
const conversation = document.getElementById('conversation');
|
const conversation = document.getElementById('conversation');
|
||||||
const status = document.getElementById('status');
|
const status = document.getElementById('status');
|
||||||
const visualizerToggle = document.getElementById('visualizerToggle');
|
|
||||||
const visualizerContainer = document.getElementById('visualizerContainer');
|
const visualizerContainer = document.getElementById('visualizerContainer');
|
||||||
const visualizerLabel = document.getElementById('visualizerLabel');
|
const visualizerLabel = document.getElementById('visualizerLabel');
|
||||||
const liveTranscription = document.getElementById('liveTranscription');
|
|
||||||
const canvas = document.getElementById('visualizer');
|
const canvas = document.getElementById('visualizer');
|
||||||
const canvasCtx = canvas.getContext('2d');
|
const canvasCtx = canvas.getContext('2d');
|
||||||
|
|
||||||
@@ -386,19 +297,6 @@
|
|||||||
canvas.height = visualizerContainer.offsetHeight;
|
canvas.height = visualizerContainer.offsetHeight;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle visualizer toggle
|
|
||||||
visualizerToggle.addEventListener('change', function() {
|
|
||||||
visualizerActive = this.checked;
|
|
||||||
visualizerContainer.style.display = visualizerActive ? 'block' : 'none';
|
|
||||||
|
|
||||||
if (!visualizerActive && visualizerAnimationId) {
|
|
||||||
cancelAnimationFrame(visualizerAnimationId);
|
|
||||||
visualizerAnimationId = null;
|
|
||||||
} else if (visualizerActive && audioAnalyser) {
|
|
||||||
drawVisualizer();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect to server
|
// Connect to server
|
||||||
socket.on('connect', () => {
|
socket.on('connect', () => {
|
||||||
status.textContent = 'Connected to server';
|
status.textContent = 'Connected to server';
|
||||||
@@ -491,8 +389,8 @@
|
|||||||
setupScriptProcessor(stream);
|
setupScriptProcessor(stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup talk button
|
// Setup mic button
|
||||||
talkButton.addEventListener('click', toggleTalking);
|
micButton.addEventListener('click', toggleTalking);
|
||||||
|
|
||||||
// Setup keyboard shortcuts
|
// Setup keyboard shortcuts
|
||||||
document.addEventListener('keydown', (e) => {
|
document.addEventListener('keydown', (e) => {
|
||||||
@@ -618,8 +516,8 @@
|
|||||||
if (!sessionActive || isAITalking) return;
|
if (!sessionActive || isAITalking) return;
|
||||||
|
|
||||||
isStreaming = true;
|
isStreaming = true;
|
||||||
talkButton.classList.add('recording');
|
micButton.classList.add('listening');
|
||||||
talkButton.textContent = 'Release to Stop';
|
micButton.textContent = 'Release to Stop';
|
||||||
status.textContent = 'Listening...';
|
status.textContent = 'Listening...';
|
||||||
visualizerLabel.textContent = 'You are speaking...';
|
visualizerLabel.textContent = 'You are speaking...';
|
||||||
|
|
||||||
@@ -630,10 +528,6 @@
|
|||||||
|
|
||||||
// Tell server we're starting to speak
|
// Tell server we're starting to speak
|
||||||
socket.emit('start_speaking');
|
socket.emit('start_speaking');
|
||||||
|
|
||||||
// Clear previous transcriptions
|
|
||||||
liveTranscription.textContent = '';
|
|
||||||
liveTranscription.classList.remove('hidden');
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop talking to the assistant
|
// Stop talking to the assistant
|
||||||
@@ -641,15 +535,12 @@
|
|||||||
if (!isStreaming) return;
|
if (!isStreaming) return;
|
||||||
|
|
||||||
isStreaming = false;
|
isStreaming = false;
|
||||||
talkButton.classList.remove('recording');
|
micButton.classList.remove('listening');
|
||||||
talkButton.textContent = 'Press to Talk';
|
micButton.textContent = 'Press to Talk';
|
||||||
status.textContent = 'Processing...';
|
status.textContent = 'Processing...';
|
||||||
|
|
||||||
// Tell server we're done speaking
|
// Tell server we're done speaking
|
||||||
socket.emit('stop_speaking');
|
socket.emit('stop_speaking');
|
||||||
|
|
||||||
// Hide live transcription temporarily
|
|
||||||
liveTranscription.classList.add('hidden');
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send audio chunk to server
|
// Send audio chunk to server
|
||||||
@@ -702,8 +593,7 @@
|
|||||||
|
|
||||||
// Handle real-time transcription
|
// Handle real-time transcription
|
||||||
socket.on('live_transcription', (data) => {
|
socket.on('live_transcription', (data) => {
|
||||||
liveTranscription.textContent = data.text || '...';
|
visualizerLabel.textContent = data.text || '...';
|
||||||
liveTranscription.classList.remove('hidden');
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Handle final transcription
|
// Handle final transcription
|
||||||
@@ -749,8 +639,8 @@
|
|||||||
speakingWave.remove();
|
speakingWave.remove();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Re-enable talk button if it was disabled
|
// Re-enable mic button if it was disabled
|
||||||
talkButton.disabled = false;
|
micButton.disabled = false;
|
||||||
});
|
});
|
||||||
|
|
||||||
// Legacy handler for text-only responses
|
// Legacy handler for text-only responses
|
||||||
|
|||||||
@@ -8,14 +8,14 @@ import numpy as np
|
|||||||
from flask import Flask, render_template, request
|
from flask import Flask, render_template, request
|
||||||
from flask_socketio import SocketIO, emit
|
from flask_socketio import SocketIO, emit
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from collections import deque
|
import threading
|
||||||
|
import queue
|
||||||
import requests
|
import requests
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
from generator import load_csm_1b, Segment
|
from generator import load_csm_1b, Segment
|
||||||
import threading
|
from collections import deque
|
||||||
import queue
|
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
|
import webrtcvad # For voice activity detection
|
||||||
|
|
||||||
# Configure environment with longer timeouts
|
# Configure environment with longer timeouts
|
||||||
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600" # 10 minutes timeout for downloads
|
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600" # 10 minutes timeout for downloads
|
||||||
@@ -28,7 +28,7 @@ app = Flask(__name__)
|
|||||||
app.config['SECRET_KEY'] = 'your-secret-key'
|
app.config['SECRET_KEY'] = 'your-secret-key'
|
||||||
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
|
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
|
||||||
|
|
||||||
# Explicitly check for CUDA and print more detailed info
|
# Explicitly check for CUDA and print detailed info
|
||||||
print("\n=== CUDA Information ===")
|
print("\n=== CUDA Information ===")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
print(f"CUDA is available")
|
print(f"CUDA is available")
|
||||||
@@ -47,18 +47,9 @@ try:
|
|||||||
except:
|
except:
|
||||||
print("cuDNN is not available (libcudnn_ops_infer.so.8 not found)")
|
print("cuDNN is not available (libcudnn_ops_infer.so.8 not found)")
|
||||||
|
|
||||||
# Check for other compute platforms
|
# Determine compute device
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
print("MPS (Apple Silicon) is available")
|
|
||||||
else:
|
|
||||||
print("MPS is not available")
|
|
||||||
print("========================\n")
|
|
||||||
|
|
||||||
# Check for CUDA availability and handle potential CUDA/cuDNN issues
|
|
||||||
try:
|
try:
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
# Try to initialize CUDA to check if libraries are properly loaded
|
|
||||||
_ = torch.zeros(1).cuda()
|
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
whisper_compute_type = "float16"
|
whisper_compute_type = "float16"
|
||||||
print("🟢 CUDA is available and initialized successfully")
|
print("🟢 CUDA is available and initialized successfully")
|
||||||
@@ -83,14 +74,42 @@ whisper_model = None
|
|||||||
csm_generator = None
|
csm_generator = None
|
||||||
llm_model = None
|
llm_model = None
|
||||||
llm_tokenizer = None
|
llm_tokenizer = None
|
||||||
|
vad = None
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
SAMPLE_RATE = 16000 # For VAD
|
||||||
|
VAD_FRAME_SIZE = 480 # 30ms at 16kHz for VAD
|
||||||
|
VAD_MODE = 3 # Aggressive mode for better results
|
||||||
|
AUDIO_CHUNK_SIZE = 2400 # 100ms chunks when streaming AI voice
|
||||||
|
|
||||||
|
# Audio sample rates
|
||||||
|
CLIENT_SAMPLE_RATE = 44100 # Browser WebAudio default
|
||||||
|
WHISPER_SAMPLE_RATE = 16000 # Whisper expects 16kHz
|
||||||
|
|
||||||
|
# Session data structures
|
||||||
|
user_sessions = {} # session_id -> complete session data
|
||||||
|
|
||||||
|
# WebRTC ICE servers (STUN/TURN servers for NAT traversal)
|
||||||
|
ICE_SERVERS = [
|
||||||
|
{"urls": "stun:stun.l.google.com:19302"},
|
||||||
|
{"urls": "stun:stun1.l.google.com:19302"}
|
||||||
|
]
|
||||||
|
|
||||||
def load_models():
|
def load_models():
|
||||||
global whisper_model, csm_generator, llm_model, llm_tokenizer
|
"""Load all necessary models"""
|
||||||
|
global whisper_model, csm_generator, llm_model, llm_tokenizer, vad
|
||||||
|
|
||||||
|
# Initialize Voice Activity Detector
|
||||||
|
try:
|
||||||
|
vad = webrtcvad.Vad(VAD_MODE)
|
||||||
|
print("Voice Activity Detector initialized")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error initializing VAD: {e}")
|
||||||
|
vad = None
|
||||||
|
|
||||||
# Initialize Faster-Whisper for transcription
|
# Initialize Faster-Whisper for transcription
|
||||||
try:
|
try:
|
||||||
print("Loading Whisper model...")
|
print("Loading Whisper model...")
|
||||||
# Import here to avoid immediate import errors if package is missing
|
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper")
|
whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper")
|
||||||
print("Whisper model loaded successfully")
|
print("Whisper model loaded successfully")
|
||||||
@@ -110,9 +129,8 @@ def load_models():
|
|||||||
# Initialize Llama 3.2 model for response generation
|
# Initialize Llama 3.2 model for response generation
|
||||||
try:
|
try:
|
||||||
print("Loading Llama 3.2 model...")
|
print("Loading Llama 3.2 model...")
|
||||||
llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources
|
llm_model_id = "meta-llama/Llama-3.2-1B"
|
||||||
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama")
|
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama")
|
||||||
# Use the right data type based on device
|
|
||||||
dtype = torch.bfloat16 if device != "cpu" else torch.float32
|
dtype = torch.bfloat16 if device != "cpu" else torch.float32
|
||||||
llm_model = AutoModelForCausalLM.from_pretrained(
|
llm_model = AutoModelForCausalLM.from_pretrained(
|
||||||
llm_model_id,
|
llm_model_id,
|
||||||
@@ -126,247 +144,339 @@ def load_models():
|
|||||||
print(f"Error loading Llama 3.2 model: {e}")
|
print(f"Error loading Llama 3.2 model: {e}")
|
||||||
print("Will use a fallback response generation method")
|
print("Will use a fallback response generation method")
|
||||||
|
|
||||||
# Store conversation context
|
|
||||||
conversation_context = {} # session_id -> context
|
|
||||||
active_audio_streams = {} # session_id -> stream status
|
|
||||||
|
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
def index():
|
def index():
|
||||||
|
"""Serve the main interface"""
|
||||||
return render_template('index.html')
|
return render_template('index.html')
|
||||||
|
|
||||||
@socketio.on('connect')
|
@socketio.on('connect')
|
||||||
def handle_connect():
|
def handle_connect():
|
||||||
print(f"Client connected: {request.sid}")
|
"""Handle new client connection"""
|
||||||
conversation_context[request.sid] = {
|
session_id = request.sid
|
||||||
|
print(f"Client connected: {session_id}")
|
||||||
|
|
||||||
|
# Initialize session data
|
||||||
|
user_sessions[session_id] = {
|
||||||
|
# Conversation context
|
||||||
'segments': [],
|
'segments': [],
|
||||||
'speakers': [0, 1], # 0 = user, 1 = bot
|
'conversation_history': [],
|
||||||
'audio_buffer': deque(maxlen=10), # Store recent audio chunks
|
'is_turn_active': False,
|
||||||
'is_speaking': False,
|
|
||||||
'last_activity': time.time(),
|
# Audio buffers and state
|
||||||
'active_session': True,
|
'vad_buffer': deque(maxlen=30), # ~1s of audio at 30fps
|
||||||
'transcription_buffer': [] # For real-time transcription
|
'audio_buffer': bytearray(),
|
||||||
|
'is_user_speaking': False,
|
||||||
|
'last_vad_active': time.time(),
|
||||||
|
'silence_duration': 0,
|
||||||
|
'speech_frames': 0,
|
||||||
|
|
||||||
|
# AI state
|
||||||
|
'is_ai_speaking': False,
|
||||||
|
'should_interrupt_ai': False,
|
||||||
|
'ai_stream_queue': queue.Queue(),
|
||||||
|
|
||||||
|
# WebRTC status
|
||||||
|
'webrtc_connected': False,
|
||||||
|
'webrtc_peer_id': None,
|
||||||
|
|
||||||
|
# Processing flags
|
||||||
|
'is_processing': False,
|
||||||
|
'pending_user_audio': None
|
||||||
}
|
}
|
||||||
emit('ready', {
|
|
||||||
'message': 'Connection established',
|
# Send config to client
|
||||||
'sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000
|
emit('session_ready', {
|
||||||
|
'whisper_available': whisper_model is not None,
|
||||||
|
'csm_available': csm_generator is not None,
|
||||||
|
'llm_available': llm_model is not None,
|
||||||
|
'client_sample_rate': CLIENT_SAMPLE_RATE,
|
||||||
|
'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000,
|
||||||
|
'ice_servers': ICE_SERVERS
|
||||||
})
|
})
|
||||||
|
|
||||||
@socketio.on('disconnect')
|
@socketio.on('disconnect')
|
||||||
def handle_disconnect():
|
def handle_disconnect():
|
||||||
print(f"Client disconnected: {request.sid}")
|
"""Handle client disconnection"""
|
||||||
session_id = request.sid
|
session_id = request.sid
|
||||||
|
print(f"Client disconnected: {session_id}")
|
||||||
|
|
||||||
# Clean up resources
|
# Clean up resources
|
||||||
if session_id in conversation_context:
|
if session_id in user_sessions:
|
||||||
conversation_context[session_id]['active_session'] = False
|
# Signal any running threads to stop
|
||||||
del conversation_context[session_id]
|
user_sessions[session_id]['should_interrupt_ai'] = True
|
||||||
|
|
||||||
if session_id in active_audio_streams:
|
# Clean up resources
|
||||||
active_audio_streams[session_id]['active'] = False
|
del user_sessions[session_id]
|
||||||
del active_audio_streams[session_id]
|
|
||||||
|
@socketio.on('webrtc_signal')
|
||||||
|
def handle_webrtc_signal(data):
|
||||||
|
"""Handle WebRTC signaling for P2P connection establishment"""
|
||||||
|
session_id = request.sid
|
||||||
|
if session_id not in user_sessions:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Simply relay the signal to the client
|
||||||
|
# In a multi-user app, we would route this to the correct peer
|
||||||
|
emit('webrtc_signal', data)
|
||||||
|
|
||||||
|
@socketio.on('webrtc_connected')
|
||||||
|
def handle_webrtc_connected(data):
|
||||||
|
"""Client notifies that WebRTC connection is established"""
|
||||||
|
session_id = request.sid
|
||||||
|
if session_id not in user_sessions:
|
||||||
|
return
|
||||||
|
|
||||||
|
user_sessions[session_id]['webrtc_connected'] = True
|
||||||
|
print(f"WebRTC connected for session {session_id}")
|
||||||
|
emit('ready_for_speech', {'message': 'Ready to start conversation'})
|
||||||
|
|
||||||
@socketio.on('audio_stream')
|
@socketio.on('audio_stream')
|
||||||
def handle_audio_stream(data):
|
def handle_audio_stream(data):
|
||||||
"""Handle incoming audio stream from client"""
|
"""Process incoming audio stream packets from client"""
|
||||||
session_id = request.sid
|
session_id = request.sid
|
||||||
|
if session_id not in user_sessions:
|
||||||
if session_id not in conversation_context:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
context = conversation_context[session_id]
|
session = user_sessions[session_id]
|
||||||
context['last_activity'] = time.time()
|
|
||||||
|
|
||||||
# Process different stream events
|
|
||||||
if data.get('event') == 'start':
|
|
||||||
# Client is starting to send audio
|
|
||||||
context['is_speaking'] = True
|
|
||||||
context['audio_buffer'].clear()
|
|
||||||
context['transcription_buffer'] = []
|
|
||||||
print(f"User {session_id} started streaming audio")
|
|
||||||
|
|
||||||
# If AI was speaking, interrupt it
|
|
||||||
if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
|
|
||||||
active_audio_streams[session_id]['active'] = False
|
|
||||||
emit('ai_stream_interrupt', {}, room=session_id)
|
|
||||||
|
|
||||||
elif data.get('event') == 'data':
|
|
||||||
# Audio data received
|
|
||||||
if not context['is_speaking']:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Decode audio chunk
|
|
||||||
try:
|
|
||||||
audio_data = base64.b64decode(data.get('audio', ''))
|
|
||||||
if not audio_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
|
|
||||||
|
|
||||||
# Apply a simple noise gate
|
|
||||||
if np.mean(np.abs(audio_numpy)) < 0.01: # Very quiet
|
|
||||||
return
|
|
||||||
|
|
||||||
audio_tensor = torch.tensor(audio_numpy)
|
|
||||||
|
|
||||||
# Add to audio buffer
|
|
||||||
context['audio_buffer'].append(audio_tensor)
|
|
||||||
|
|
||||||
# Real-time transcription (periodic)
|
|
||||||
if len(context['audio_buffer']) % 3 == 0: # Process every 3 chunks
|
|
||||||
threading.Thread(
|
|
||||||
target=process_realtime_transcription,
|
|
||||||
args=(session_id,),
|
|
||||||
daemon=True
|
|
||||||
).start()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing audio chunk: {e}")
|
|
||||||
|
|
||||||
elif data.get('event') == 'end':
|
|
||||||
# Client has finished sending audio
|
|
||||||
context['is_speaking'] = False
|
|
||||||
|
|
||||||
if len(context['audio_buffer']) > 0:
|
|
||||||
# Process the complete utterance
|
|
||||||
threading.Thread(
|
|
||||||
target=process_complete_utterance,
|
|
||||||
args=(session_id,),
|
|
||||||
daemon=True
|
|
||||||
).start()
|
|
||||||
|
|
||||||
print(f"User {session_id} stopped streaming audio")
|
|
||||||
|
|
||||||
def process_realtime_transcription(session_id):
|
|
||||||
"""Process incoming audio for real-time transcription"""
|
|
||||||
if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
|
|
||||||
return
|
|
||||||
|
|
||||||
context = conversation_context[session_id]
|
|
||||||
|
|
||||||
if not context['audio_buffer'] or not context['is_speaking']:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Combine current buffer for transcription
|
# Decode audio data
|
||||||
buffer_copy = list(context['audio_buffer'])
|
audio_bytes = base64.b64decode(data.get('audio', ''))
|
||||||
if not buffer_copy:
|
if not audio_bytes or len(audio_bytes) < 2: # Need at least one sample
|
||||||
return
|
return
|
||||||
|
|
||||||
full_audio = torch.cat(buffer_copy, dim=0)
|
# Add to current audio buffer
|
||||||
|
session['audio_buffer'] += audio_bytes
|
||||||
|
|
||||||
# Save audio to temporary WAV file for transcription
|
# Check for speech using VAD
|
||||||
temp_audio_path = f"temp_rt_{session_id}.wav"
|
has_speech = detect_speech(audio_bytes, session_id)
|
||||||
|
|
||||||
|
# Handle speech state machine
|
||||||
|
if has_speech:
|
||||||
|
# Reset silence tracking when speech is detected
|
||||||
|
session['last_vad_active'] = time.time()
|
||||||
|
session['silence_duration'] = 0
|
||||||
|
session['speech_frames'] += 1
|
||||||
|
|
||||||
|
# If not already marked as speaking and we have enough speech frames
|
||||||
|
if not session['is_user_speaking'] and session['speech_frames'] >= 5:
|
||||||
|
on_speech_started(session_id)
|
||||||
|
else:
|
||||||
|
# No speech detected in this frame
|
||||||
|
if session['is_user_speaking']:
|
||||||
|
# Calculate silence duration
|
||||||
|
now = time.time()
|
||||||
|
session['silence_duration'] = now - session['last_vad_active']
|
||||||
|
|
||||||
|
# If silent for more than 0.5 seconds, end speech segment
|
||||||
|
if session['silence_duration'] > 0.8 and session['speech_frames'] > 8:
|
||||||
|
on_speech_ended(session_id)
|
||||||
|
else:
|
||||||
|
# Not speaking and no speech, just a silent frame
|
||||||
|
session['speech_frames'] = max(0, session['speech_frames'] - 1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing audio stream: {e}")
|
||||||
|
|
||||||
|
def detect_speech(audio_bytes, session_id):
|
||||||
|
"""Use VAD to check if audio contains speech"""
|
||||||
|
if session_id not in user_sessions:
|
||||||
|
return False
|
||||||
|
|
||||||
|
session = user_sessions[session_id]
|
||||||
|
|
||||||
|
# Store in VAD buffer for history
|
||||||
|
session['vad_buffer'].append(audio_bytes)
|
||||||
|
|
||||||
|
if vad is None:
|
||||||
|
# Fallback to simple energy detection
|
||||||
|
audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
|
||||||
|
energy = np.mean(np.abs(audio_data)) / 32768.0
|
||||||
|
return energy > 0.015 # Simple threshold
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Ensure we have the right amount of data for VAD
|
||||||
|
audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
|
||||||
|
|
||||||
|
# If we have too much data, use just the right amount
|
||||||
|
if len(audio_data) >= VAD_FRAME_SIZE:
|
||||||
|
frame = audio_data[:VAD_FRAME_SIZE].tobytes()
|
||||||
|
return vad.is_speech(frame, SAMPLE_RATE)
|
||||||
|
|
||||||
|
# If too little data, accumulate in the VAD buffer and check periodically
|
||||||
|
if len(session['vad_buffer']) >= 3:
|
||||||
|
# Combine recent chunks to get enough data
|
||||||
|
combined = bytearray()
|
||||||
|
for chunk in list(session['vad_buffer'])[-3:]:
|
||||||
|
combined.extend(chunk)
|
||||||
|
|
||||||
|
# Extract the right amount of data
|
||||||
|
if len(combined) >= VAD_FRAME_SIZE:
|
||||||
|
frame = combined[:VAD_FRAME_SIZE]
|
||||||
|
return vad.is_speech(bytes(frame), SAMPLE_RATE)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"VAD error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def on_speech_started(session_id):
|
||||||
|
"""Handle start of user speech"""
|
||||||
|
if session_id not in user_sessions:
|
||||||
|
return
|
||||||
|
|
||||||
|
session = user_sessions[session_id]
|
||||||
|
|
||||||
|
# Reset audio buffer
|
||||||
|
session['audio_buffer'] = bytearray()
|
||||||
|
session['is_user_speaking'] = True
|
||||||
|
session['is_turn_active'] = True
|
||||||
|
|
||||||
|
# If AI is speaking, we need to interrupt it
|
||||||
|
if session['is_ai_speaking']:
|
||||||
|
session['should_interrupt_ai'] = True
|
||||||
|
emit('ai_interrupted_by_user', room=session_id)
|
||||||
|
|
||||||
|
# Notify client that we detected speech
|
||||||
|
emit('user_speech_start', room=session_id)
|
||||||
|
|
||||||
|
def on_speech_ended(session_id):
|
||||||
|
"""Handle end of user speech segment"""
|
||||||
|
if session_id not in user_sessions:
|
||||||
|
return
|
||||||
|
|
||||||
|
session = user_sessions[session_id]
|
||||||
|
|
||||||
|
# Mark as not speaking anymore
|
||||||
|
session['is_user_speaking'] = False
|
||||||
|
session['speech_frames'] = 0
|
||||||
|
|
||||||
|
# If no audio or already processing, skip
|
||||||
|
if len(session['audio_buffer']) < 4000 or session['is_processing']: # At least 250ms of audio
|
||||||
|
session['audio_buffer'] = bytearray()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Mark as processing to prevent multiple processes
|
||||||
|
session['is_processing'] = True
|
||||||
|
|
||||||
|
# Create a copy of the audio buffer
|
||||||
|
audio_copy = session['audio_buffer']
|
||||||
|
session['audio_buffer'] = bytearray()
|
||||||
|
|
||||||
|
# Convert audio to the format needed for processing
|
||||||
|
try:
|
||||||
|
# Convert to float32 between -1 and 1
|
||||||
|
audio_np = np.frombuffer(audio_copy, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
audio_tensor = torch.from_numpy(audio_np)
|
||||||
|
|
||||||
|
# Resample to Whisper's expected sample rate if necessary
|
||||||
|
if CLIENT_SAMPLE_RATE != WHISPER_SAMPLE_RATE:
|
||||||
|
audio_tensor = torchaudio.functional.resample(
|
||||||
|
audio_tensor,
|
||||||
|
orig_freq=CLIENT_SAMPLE_RATE,
|
||||||
|
new_freq=WHISPER_SAMPLE_RATE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save as WAV for transcription
|
||||||
|
temp_audio_path = f"temp_audio_{session_id}.wav"
|
||||||
torchaudio.save(
|
torchaudio.save(
|
||||||
temp_audio_path,
|
temp_audio_path,
|
||||||
full_audio.unsqueeze(0),
|
audio_tensor.unsqueeze(0),
|
||||||
44100 # Assuming 44.1kHz from client
|
WHISPER_SAMPLE_RATE
|
||||||
)
|
)
|
||||||
|
|
||||||
# Transcribe with Whisper if available
|
# Start transcription and response process in a thread
|
||||||
if whisper_model is not None:
|
threading.Thread(
|
||||||
segments, _ = whisper_model.transcribe(temp_audio_path, beam_size=5)
|
target=process_user_utterance,
|
||||||
text = " ".join([segment.text for segment in segments])
|
args=(session_id, temp_audio_path, audio_tensor),
|
||||||
|
daemon=True
|
||||||
|
).start()
|
||||||
|
|
||||||
|
# Notify client that processing has started
|
||||||
|
emit('processing_speech', room=session_id)
|
||||||
|
|
||||||
if text.strip():
|
|
||||||
context['transcription_buffer'].append(text)
|
|
||||||
# Send partial transcription to client
|
|
||||||
emit('partial_transcription', {'text': text}, room=session_id)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in realtime transcription: {e}")
|
print(f"Error preparing audio: {e}")
|
||||||
finally:
|
session['is_processing'] = False
|
||||||
# Clean up
|
emit('error', {'message': f'Error processing audio: {str(e)}'}, room=session_id)
|
||||||
if os.path.exists(temp_audio_path):
|
|
||||||
os.remove(temp_audio_path)
|
|
||||||
|
|
||||||
def process_complete_utterance(session_id):
|
def process_user_utterance(session_id, audio_path, audio_tensor):
|
||||||
"""Process completed user utterance, generate response and stream audio back"""
|
"""Process user utterance, transcribe and generate response"""
|
||||||
if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
|
if session_id not in user_sessions:
|
||||||
return
|
return
|
||||||
|
|
||||||
context = conversation_context[session_id]
|
session = user_sessions[session_id]
|
||||||
|
|
||||||
if not context['audio_buffer']:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Combine audio chunks
|
|
||||||
full_audio = torch.cat(list(context['audio_buffer']), dim=0)
|
|
||||||
context['audio_buffer'].clear()
|
|
||||||
|
|
||||||
# Save audio to temporary WAV file for transcription
|
|
||||||
temp_audio_path = f"temp_audio_{session_id}.wav"
|
|
||||||
torchaudio.save(
|
|
||||||
temp_audio_path,
|
|
||||||
full_audio.unsqueeze(0),
|
|
||||||
44100 # Assuming 44.1kHz from client
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Try using Whisper first if available
|
# Transcribe audio
|
||||||
if whisper_model is not None:
|
if whisper_model is not None:
|
||||||
user_text = transcribe_with_whisper(temp_audio_path)
|
user_text = transcribe_with_whisper(audio_path)
|
||||||
else:
|
else:
|
||||||
# Fallback to Google's speech recognition
|
# Fallback to another transcription service
|
||||||
user_text = transcribe_with_google(temp_audio_path)
|
user_text = transcribe_fallback(audio_path)
|
||||||
|
|
||||||
if not user_text:
|
# Clean up temp file
|
||||||
print("No speech detected.")
|
if os.path.exists(audio_path):
|
||||||
emit('error', {'message': 'No speech detected. Please try again.'}, room=session_id)
|
os.remove(audio_path)
|
||||||
|
|
||||||
|
# Check if we got meaningful text
|
||||||
|
if not user_text or len(user_text.strip()) < 2:
|
||||||
|
emit('no_speech_detected', room=session_id)
|
||||||
|
session['is_processing'] = False
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"Transcribed: {user_text}")
|
print(f"Transcribed: {user_text}")
|
||||||
|
|
||||||
# Add to conversation segments
|
# Create user segment
|
||||||
user_segment = Segment(
|
user_segment = Segment(
|
||||||
text=user_text,
|
text=user_text,
|
||||||
speaker=0, # User is speaker 0
|
speaker=0, # User is speaker 0
|
||||||
audio=full_audio
|
audio=audio_tensor
|
||||||
)
|
)
|
||||||
context['segments'].append(user_segment)
|
session['segments'].append(user_segment)
|
||||||
|
|
||||||
# Generate bot response text
|
# Update conversation history
|
||||||
bot_response = generate_llm_response(user_text, context['segments'])
|
session['conversation_history'].append({
|
||||||
print(f"Bot response: {bot_response}")
|
'role': 'user',
|
||||||
|
'text': user_text
|
||||||
|
})
|
||||||
|
|
||||||
# Send transcribed text to client
|
# Send transcription to client
|
||||||
emit('transcription', {'text': user_text}, room=session_id)
|
emit('transcription', {'text': user_text}, room=session_id)
|
||||||
|
|
||||||
# Generate and stream audio response if CSM is available
|
# Generate AI response
|
||||||
|
ai_response = generate_ai_response(user_text, session_id)
|
||||||
|
|
||||||
|
# Send text response to client
|
||||||
|
emit('ai_response_text', {'text': ai_response}, room=session_id)
|
||||||
|
|
||||||
|
# Update conversation history
|
||||||
|
session['conversation_history'].append({
|
||||||
|
'role': 'assistant',
|
||||||
|
'text': ai_response
|
||||||
|
})
|
||||||
|
|
||||||
|
# Generate voice response if CSM is available
|
||||||
if csm_generator is not None:
|
if csm_generator is not None:
|
||||||
# Create stream state object
|
session['is_ai_speaking'] = True
|
||||||
active_audio_streams[session_id] = {
|
session['should_interrupt_ai'] = False
|
||||||
'active': True,
|
|
||||||
'text': bot_response
|
|
||||||
}
|
|
||||||
|
|
||||||
# Send initial response to prepare client
|
# Begin streaming audio response
|
||||||
emit('ai_stream_start', {
|
|
||||||
'text': bot_response
|
|
||||||
}, room=session_id)
|
|
||||||
|
|
||||||
# Start audio generation in a separate thread
|
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=generate_and_stream_audio_realtime,
|
target=stream_ai_response,
|
||||||
args=(bot_response, context['segments'], session_id),
|
args=(ai_response, session_id),
|
||||||
daemon=True
|
daemon=True
|
||||||
).start()
|
).start()
|
||||||
else:
|
|
||||||
# Send text-only response if audio generation isn't available
|
|
||||||
emit('text_response', {'text': bot_response}, room=session_id)
|
|
||||||
|
|
||||||
# Add text-only bot response to conversation history
|
|
||||||
bot_segment = Segment(
|
|
||||||
text=bot_response,
|
|
||||||
speaker=1, # Bot is speaker 1
|
|
||||||
audio=torch.zeros(1) # Placeholder empty audio
|
|
||||||
)
|
|
||||||
context['segments'].append(bot_segment)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing speech: {e}")
|
print(f"Error processing utterance: {e}")
|
||||||
emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id)
|
emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Cleanup temp file
|
# Clear processing flag
|
||||||
if os.path.exists(temp_audio_path):
|
if session_id in user_sessions:
|
||||||
os.remove(temp_audio_path)
|
session['is_processing'] = False
|
||||||
|
|
||||||
def transcribe_with_whisper(audio_path):
|
def transcribe_with_whisper(audio_path):
|
||||||
"""Transcribe audio using Faster-Whisper"""
|
"""Transcribe audio using Faster-Whisper"""
|
||||||
@@ -375,49 +485,58 @@ def transcribe_with_whisper(audio_path):
|
|||||||
# Collect all text from segments
|
# Collect all text from segments
|
||||||
user_text = ""
|
user_text = ""
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
segment_text = segment.text.strip()
|
user_text += segment.text.strip() + " "
|
||||||
print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}")
|
|
||||||
user_text += segment_text + " "
|
|
||||||
|
|
||||||
print(f"Transcribed text: {user_text.strip()}")
|
|
||||||
|
|
||||||
return user_text.strip()
|
return user_text.strip()
|
||||||
|
|
||||||
def transcribe_with_google(audio_path):
|
def transcribe_fallback(audio_path):
|
||||||
"""Fallback transcription using Google's speech recognition"""
|
"""Fallback transcription using Google's speech recognition"""
|
||||||
import speech_recognition as sr
|
try:
|
||||||
recognizer = sr.Recognizer()
|
import speech_recognition as sr
|
||||||
|
recognizer = sr.Recognizer()
|
||||||
|
|
||||||
with sr.AudioFile(audio_path) as source:
|
with sr.AudioFile(audio_path) as source:
|
||||||
audio = recognizer.record(source)
|
audio = recognizer.record(source)
|
||||||
try:
|
try:
|
||||||
text = recognizer.recognize_google(audio)
|
text = recognizer.recognize_google(audio)
|
||||||
return text
|
return text
|
||||||
except sr.UnknownValueError:
|
except sr.UnknownValueError:
|
||||||
return ""
|
return ""
|
||||||
except sr.RequestError:
|
except sr.RequestError:
|
||||||
# If Google API fails, try a basic energy-based VAD approach
|
return "[Speech recognition service unavailable]"
|
||||||
# This is a very basic fallback and won't give good results
|
except ImportError:
|
||||||
return "[Speech detected but transcription failed]"
|
return "[Speech recognition not available]"
|
||||||
|
|
||||||
|
def generate_ai_response(user_text, session_id):
|
||||||
|
"""Generate text response using available LLM"""
|
||||||
|
if session_id not in user_sessions:
|
||||||
|
return "I'm sorry, your session has expired."
|
||||||
|
|
||||||
|
session = user_sessions[session_id]
|
||||||
|
|
||||||
def generate_llm_response(user_text, conversation_segments):
|
|
||||||
"""Generate text response using available model"""
|
|
||||||
if llm_model is not None and llm_tokenizer is not None:
|
if llm_model is not None and llm_tokenizer is not None:
|
||||||
# Format conversation history for the LLM
|
# Format conversation history for the LLM
|
||||||
conversation_history = ""
|
prompt = "You are a helpful, friendly voice assistant. Keep your responses brief and conversational.\n\n"
|
||||||
for segment in conversation_segments[-5:]: # Use last 5 utterances for context
|
|
||||||
speaker_name = "User" if segment.speaker == 0 else "Assistant"
|
|
||||||
conversation_history += f"{speaker_name}: {segment.text}\n"
|
|
||||||
|
|
||||||
# Add the current user query
|
# Add recent conversation history (last 6 turns maximum)
|
||||||
conversation_history += f"User: {user_text}\nAssistant:"
|
for entry in session['conversation_history'][-6:]:
|
||||||
|
if entry['role'] == 'user':
|
||||||
|
prompt += f"User: {entry['text']}\n"
|
||||||
|
else:
|
||||||
|
prompt += f"Assistant: {entry['text']}\n"
|
||||||
|
|
||||||
|
# Add current query if not already in history
|
||||||
|
if not session['conversation_history'] or session['conversation_history'][-1]['role'] != 'user':
|
||||||
|
prompt += f"User: {user_text}\n"
|
||||||
|
|
||||||
|
prompt += "Assistant: "
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate response
|
# Generate response
|
||||||
inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device)
|
inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
|
||||||
output = llm_model.generate(
|
output = llm_model.generate(
|
||||||
inputs.input_ids,
|
inputs.input_ids,
|
||||||
max_new_tokens=150,
|
max_new_tokens=100, # Keep responses shorter for voice
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
do_sample=True
|
do_sample=True
|
||||||
@@ -425,43 +544,48 @@ def generate_llm_response(user_text, conversation_segments):
|
|||||||
|
|
||||||
response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error generating response with LLM: {e}")
|
print(f"Error generating LLM response: {e}")
|
||||||
return fallback_response(user_text)
|
return fallback_response(user_text)
|
||||||
else:
|
else:
|
||||||
return fallback_response(user_text)
|
return fallback_response(user_text)
|
||||||
|
|
||||||
def fallback_response(user_text):
|
def fallback_response(user_text):
|
||||||
"""Generate a simple fallback response when LLM is not available"""
|
"""Generate simple fallback responses when LLM is unavailable"""
|
||||||
# Simple rule-based responses
|
|
||||||
user_text_lower = user_text.lower()
|
user_text_lower = user_text.lower()
|
||||||
|
|
||||||
if "hello" in user_text_lower or "hi" in user_text_lower:
|
if "hello" in user_text_lower or "hi" in user_text_lower:
|
||||||
return "Hello! I'm a simple fallback assistant. The main language model couldn't be loaded, so I have limited capabilities."
|
return "Hello! How can I help you today?"
|
||||||
|
|
||||||
elif "how are you" in user_text_lower:
|
elif "how are you" in user_text_lower:
|
||||||
return "I'm functioning within my limited capabilities. How can I assist you today?"
|
return "I'm doing well, thanks for asking! How about you?"
|
||||||
|
|
||||||
elif "thank" in user_text_lower:
|
elif "thank" in user_text_lower:
|
||||||
return "You're welcome! Let me know if there's anything else I can help with."
|
return "You're welcome! Happy to help."
|
||||||
|
|
||||||
elif "bye" in user_text_lower or "goodbye" in user_text_lower:
|
elif "bye" in user_text_lower or "goodbye" in user_text_lower:
|
||||||
return "Goodbye! Have a great day!"
|
return "Goodbye! Have a great day!"
|
||||||
|
|
||||||
elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]):
|
elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]):
|
||||||
return "I'm running in fallback mode and can't answer complex questions. Please try again when the main language model is available."
|
return "That's an interesting question. I wish I could provide a better answer in my current fallback mode."
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available."
|
return "I see. Tell me more about that."
|
||||||
|
|
||||||
def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
|
def stream_ai_response(text, session_id):
|
||||||
"""Generate audio response using CSM and stream it in real-time to client"""
|
"""Generate and stream audio response in real-time chunks"""
|
||||||
if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
|
if session_id not in user_sessions:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
session = user_sessions[session_id]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use the last few conversation segments as context
|
# Signal start of AI speech
|
||||||
context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
|
emit('ai_speech_start', room=session_id)
|
||||||
|
|
||||||
|
# Use the last few conversation segments as context (up to 4)
|
||||||
|
context_segments = session['segments'][-4:] if len(session['segments']) > 4 else session['segments']
|
||||||
|
|
||||||
# Generate audio for bot response
|
# Generate audio for bot response
|
||||||
audio = csm_generator.generate(
|
audio = csm_generator.generate(
|
||||||
@@ -473,23 +597,26 @@ def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
|
|||||||
topk=50
|
topk=50
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store the full audio for conversation history
|
# Create and store bot segment
|
||||||
bot_segment = Segment(
|
bot_segment = Segment(
|
||||||
text=text,
|
text=text,
|
||||||
speaker=1, # Bot is speaker 1
|
speaker=1,
|
||||||
audio=audio
|
audio=audio
|
||||||
)
|
)
|
||||||
if session_id in conversation_context and conversation_context[session_id]['active_session']:
|
|
||||||
conversation_context[session_id]['segments'].append(bot_segment)
|
if session_id in user_sessions:
|
||||||
|
session['segments'].append(bot_segment)
|
||||||
|
|
||||||
# Stream audio in small chunks for more responsive playback
|
# Stream audio in small chunks for more responsive playback
|
||||||
chunk_size = 4800 # 200ms at 24kHz
|
chunk_size = AUDIO_CHUNK_SIZE # Size defined in constants
|
||||||
|
|
||||||
for i in range(0, len(audio), chunk_size):
|
for i in range(0, len(audio), chunk_size):
|
||||||
if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
|
# Check if we should stop (user interrupted)
|
||||||
print("Audio streaming interrupted or session ended")
|
if session_id not in user_sessions or session['should_interrupt_ai']:
|
||||||
|
print("AI speech interrupted")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Get next chunk
|
||||||
chunk = audio[i:i+chunk_size]
|
chunk = audio[i:i+chunk_size]
|
||||||
|
|
||||||
# Convert audio chunk to base64 for streaming
|
# Convert audio chunk to base64 for streaming
|
||||||
@@ -499,32 +626,48 @@ def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
|
|||||||
audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')
|
audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')
|
||||||
|
|
||||||
# Send chunk to client
|
# Send chunk to client
|
||||||
socketio.emit('ai_stream_data', {
|
socketio.emit('ai_speech_chunk', {
|
||||||
'audio': audio_b64,
|
'audio': audio_b64,
|
||||||
'is_last': i + chunk_size >= len(audio)
|
'is_last': i + chunk_size >= len(audio)
|
||||||
}, room=session_id)
|
}, room=session_id)
|
||||||
|
|
||||||
# Simulate real-time speech by adding a small delay
|
# Small sleep for more natural pacing
|
||||||
# Remove this in production for faster response
|
time.sleep(0.06) # Slight delay for smoother playback
|
||||||
time.sleep(0.15) # Slight delay for more natural timing
|
|
||||||
|
|
||||||
# Signal end of stream
|
# Signal end of AI speech
|
||||||
if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
|
if session_id in user_sessions:
|
||||||
socketio.emit('ai_stream_end', {}, room=session_id)
|
session['is_ai_speaking'] = False
|
||||||
active_audio_streams[session_id]['active'] = False
|
session['is_turn_active'] = False # End conversation turn
|
||||||
|
socketio.emit('ai_speech_end', room=session_id)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error generating or streaming audio: {e}")
|
print(f"Error streaming AI response: {e}")
|
||||||
# Send error message to client
|
if session_id in user_sessions:
|
||||||
if session_id in conversation_context and conversation_context[session_id]['active_session']:
|
session['is_ai_speaking'] = False
|
||||||
socketio.emit('error', {
|
session['is_turn_active'] = False
|
||||||
'message': f'Error generating audio: {str(e)}'
|
socketio.emit('error', {'message': f'Error generating audio: {str(e)}'}, room=session_id)
|
||||||
}, room=session_id)
|
socketio.emit('ai_speech_end', room=session_id)
|
||||||
|
|
||||||
# Signal stream end to unblock client
|
@socketio.on('interrupt_ai')
|
||||||
socketio.emit('ai_stream_end', {}, room=session_id)
|
def handle_interrupt():
|
||||||
if session_id in active_audio_streams:
|
"""Handle explicit AI interruption request from client"""
|
||||||
active_audio_streams[session_id]['active'] = False
|
session_id = request.sid
|
||||||
|
if session_id in user_sessions:
|
||||||
|
user_sessions[session_id]['should_interrupt_ai'] = True
|
||||||
|
emit('ai_interrupted', room=session_id)
|
||||||
|
|
||||||
|
@socketio.on('get_config')
|
||||||
|
def handle_get_config():
|
||||||
|
"""Send configuration to client"""
|
||||||
|
session_id = request.sid
|
||||||
|
if session_id in user_sessions:
|
||||||
|
emit('config', {
|
||||||
|
'client_sample_rate': CLIENT_SAMPLE_RATE,
|
||||||
|
'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000,
|
||||||
|
'whisper_available': whisper_model is not None,
|
||||||
|
'csm_available': csm_generator is not None,
|
||||||
|
'ice_servers': ICE_SERVERS
|
||||||
|
})
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Ensure the existing index.html file is in the correct location
|
# Ensure the existing index.html file is in the correct location
|
||||||
@@ -538,6 +681,6 @@ if __name__ == '__main__':
|
|||||||
print("Starting model loading...")
|
print("Starting model loading...")
|
||||||
load_models()
|
load_models()
|
||||||
|
|
||||||
# Start the server with eventlet for better WebSocket performance
|
# Start the server
|
||||||
print("Starting Flask SocketIO server...")
|
print("Starting Flask SocketIO server...")
|
||||||
socketio.run(app, host='0.0.0.0', port=5000, debug=False)
|
socketio.run(app, host='0.0.0.0', port=5000, debug=False)
|
||||||
Reference in New Issue
Block a user