import eventlet
eventlet.monkey_patch()

# Keep your other imports below
import os
import io
import base64
import time
import torch
import torchaudio
import numpy as np
from flask import Flask, render_template, request
from flask_socketio import SocketIO, emit
from transformers import AutoModelForCausalLM, AutoTokenizer
import threading
import queue
import requests
import huggingface_hub
from generator import load_csm_1b, Segment
from collections import deque
import json
import webrtcvad  # For voice activity detection

# Configure environment with longer timeouts
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600"  # 10 minutes timeout for downloads
requests.adapters.DEFAULT_TIMEOUT = 60  # Increase default requests timeout

# Create a models directory for caching
os.makedirs("models", exist_ok=True)

app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key'
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')

# Explicitly check for CUDA and print detailed info
print("\n=== CUDA Information ===")
if torch.cuda.is_available():
    print("CUDA is available")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("CUDA is not available")

# Check for cuDNN
try:
    import ctypes
    ctypes.CDLL("libcudnn_ops_infer.so.8")
    print("cuDNN is available")
except OSError:
    # ctypes.CDLL raises OSError when the shared library cannot be loaded
    print("cuDNN is not available (libcudnn_ops_infer.so.8 not found)")

# Determine compute device
try:
    if torch.cuda.is_available():
        device = "cuda"
        whisper_compute_type = "float16"
        print("🟢 CUDA is available and initialized successfully")
    elif torch.backends.mps.is_available():
        device = "mps"
        whisper_compute_type = "float32"
        print("🟢 MPS is available (Apple Silicon)")
    else:
        device = "cpu"
        whisper_compute_type = "int8"
        print("🟡 Using CPU (CUDA/MPS not available)")
except Exception as e:
    print(f"🔴 Error initializing CUDA: {e}")
    print("🔴 Falling back to CPU")
    device = "cpu"
    whisper_compute_type = "int8"

print(f"Using device: {device}")

# Initialize models with proper error handling
whisper_model = None
csm_generator = None
llm_model = None
llm_tokenizer = None
vad = None

# Constants
SAMPLE_RATE = 16000        # Sample rate assumed by the VAD
VAD_FRAME_SIZE = 480       # 30 ms frame at 16 kHz (480 samples)
VAD_MODE = 3               # Most aggressive WebRTC VAD mode
AUDIO_CHUNK_SIZE = 2400    # ~100 ms per chunk when streaming AI voice (at the 24 kHz output rate)

# Audio sample rates
CLIENT_SAMPLE_RATE = 44100   # Browser WebAudio default
WHISPER_SAMPLE_RATE = 16000  # Whisper expects 16 kHz

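# Quick sanity check on the constants above (a derivation from this file's own
# values, not new behavior): a 30 ms VAD frame at 16 kHz is 480 samples, i.e.
# 960 bytes of 16-bit mono PCM, and a 2400-sample chunk at the 24 kHz default
# used for 'server_sample_rate' below lasts 100 ms.
assert VAD_FRAME_SIZE == int(SAMPLE_RATE * 0.030)
assert AUDIO_CHUNK_SIZE == int(24000 * 0.100)
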
# Session data structures
user_sessions = {}  # session_id -> complete session data

# WebRTC ICE servers (STUN/TURN servers for NAT traversal)
ICE_SERVERS = [
    {"urls": "stun:stun.l.google.com:19302"},
    {"urls": "stun:stun1.l.google.com:19302"}
]

def load_models():
    """Load all necessary models"""
    global whisper_model, csm_generator, llm_model, llm_tokenizer, vad

    # Initialize Voice Activity Detector
    try:
        vad = webrtcvad.Vad(VAD_MODE)
        print("Voice Activity Detector initialized")
    except Exception as e:
        print(f"Error initializing VAD: {e}")
        vad = None

    # Initialize Faster-Whisper for transcription
    try:
        print("Loading Whisper model...")
        from faster_whisper import WhisperModel
        whisper_model = WhisperModel(
            "base",
            device=device,
            compute_type=whisper_compute_type,
            download_root="./models/whisper"
        )
        print("Whisper model loaded successfully")
    except Exception as e:
        print(f"Error loading Whisper model: {e}")
        print("Will use backup speech recognition method if available")

    # Initialize CSM model for audio generation
    try:
        print("Loading CSM model...")
        csm_generator = load_csm_1b(device=device)
        print("CSM model loaded successfully")
    except Exception as e:
        print(f"Error loading CSM model: {e}")
        print("Audio generation will not be available")

    # Initialize Llama 3.2 model for response generation
    try:
        print("Loading Llama 3.2 model...")
        llm_model_id = "meta-llama/Llama-3.2-1B"
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama")
        dtype = torch.bfloat16 if device != "cpu" else torch.float32
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_id,
            torch_dtype=dtype,
            device_map=device,
            cache_dir="./models/llama",
            low_cpu_mem_usage=True
        )
        print("Llama 3.2 model loaded successfully")
    except Exception as e:
        print(f"Error loading Llama 3.2 model: {e}")
        print("Will use a fallback response generation method")

@app.route('/')
def index():
    """Serve the main interface"""
    return render_template('index.html')


@app.route('/static/js/voice-chat.js')
def serve_voice_chat_js():
    """Serve the JavaScript file"""
    return app.send_static_file('js/voice-chat.js')

@socketio.on('connect')
def handle_connect():
    """Handle new client connection"""
    session_id = request.sid
    print(f"Client connected: {session_id}")

    # Initialize session data
    user_sessions[session_id] = {
        # Conversation context
        'segments': [],
        'conversation_history': [],
        'is_turn_active': False,

        # Audio buffers and state
        'vad_buffer': deque(maxlen=30),  # ~1 s of history at 30 ms per frame
        'audio_buffer': bytearray(),
        'is_user_speaking': False,
        'last_vad_active': time.time(),
        'silence_duration': 0,
        'speech_frames': 0,

        # AI state
        'is_ai_speaking': False,
        'should_interrupt_ai': False,
        'ai_stream_queue': queue.Queue(),

        # Processing flags
        'is_processing': False,
        'pending_user_audio': None
    }

    # Send config to client
    emit('session_ready', {
        'whisper_available': whisper_model is not None,
        'csm_available': csm_generator is not None,
        'llm_available': llm_model is not None,
        'client_sample_rate': CLIENT_SAMPLE_RATE,
        'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000
    })

    emit('ready_for_speech', {'message': 'Ready to start conversation'})

@socketio.on('disconnect')
def handle_disconnect():
    """Handle client disconnection"""
    session_id = request.sid
    print(f"Client disconnected: {session_id}")

    # Clean up resources
    if session_id in user_sessions:
        # Signal any running threads to stop
        user_sessions[session_id]['should_interrupt_ai'] = True

        # Remove the session data
        del user_sessions[session_id]

@socketio.on('webrtc_signal')
def handle_webrtc_signal(data):
    """Handle WebRTC signaling for P2P connection establishment"""
    session_id = request.sid
    if session_id not in user_sessions:
        return

    # Simply relay the signal to the client
    # In a multi-user app, we would route this to the correct peer
    emit('webrtc_signal', data)

@socketio.on('webrtc_connected')
def handle_webrtc_connected(data):
    """Client notifies that WebRTC connection is established"""
    session_id = request.sid
    if session_id not in user_sessions:
        return

    user_sessions[session_id]['webrtc_connected'] = True
    print(f"WebRTC connected for session {session_id}")
    emit('ready_for_speech', {'message': 'Ready to start conversation'})

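# Overview of the audio path implemented by the handlers below (a summary of this
# file's own logic, not additional behavior): the client emits 'audio_stream'
# packets containing base64-encoded 16-bit PCM; detect_speech() gates them with
# WebRTC VAD (or an energy-threshold fallback); on_speech_started()/on_speech_ended()
# track the turn; the buffered utterance is resampled to 16 kHz, transcribed by
# Whisper (or the Google fallback), answered by the LLM, and finally synthesized
# by CSM and streamed back as ~100 ms 'ai_speech_chunk' WAV chunks.
#
# Client-side emit shape assumed by handle_audio_stream():
#   socket.emit('audio_stream', {'audio': <base64 of int16 PCM bytes>})
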
@socketio.on('audio_stream')
def handle_audio_stream(data):
    """Process incoming audio stream packets from client"""
    session_id = request.sid
    if session_id not in user_sessions:
        return

    session = user_sessions[session_id]

    try:
        # Decode audio data
        audio_bytes = base64.b64decode(data.get('audio', ''))
        if not audio_bytes or len(audio_bytes) < 2:  # Need at least one 16-bit sample
            return

        # Add to current audio buffer
        session['audio_buffer'] += audio_bytes

        # Check for speech using VAD
        has_speech = detect_speech(audio_bytes, session_id)

        # Handle speech state machine
        if has_speech:
            # Reset silence tracking when speech is detected
            session['last_vad_active'] = time.time()
            session['silence_duration'] = 0
            session['speech_frames'] += 1

            # If not already marked as speaking and we have enough speech frames
            if not session['is_user_speaking'] and session['speech_frames'] >= 5:
                on_speech_started(session_id)
        else:
            # No speech detected in this frame
            if session['is_user_speaking']:
                # Calculate silence duration
                now = time.time()
                session['silence_duration'] = now - session['last_vad_active']

                # If silent for more than 0.8 seconds (after enough speech frames), end the segment
                if session['silence_duration'] > 0.8 and session['speech_frames'] > 8:
                    on_speech_ended(session_id)
            else:
                # Not speaking and no speech, just a silent frame
                session['speech_frames'] = max(0, session['speech_frames'] - 1)

    except Exception as e:
        print(f"Error processing audio stream: {e}")

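# Note on the VAD below: webrtcvad only accepts 16-bit mono PCM frames of 10, 20
# or 30 ms at 8, 16, 32 or 48 kHz, so incoming data is trimmed or assembled into
# exactly one VAD_FRAME_SIZE frame (480 samples = 30 ms at 16 kHz) before calling
# vad.is_speech().
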
def detect_speech(audio_bytes, session_id):
    """Use VAD to check if audio contains speech"""
    if session_id not in user_sessions:
        return False

    session = user_sessions[session_id]

    # Store in VAD buffer for history
    session['vad_buffer'].append(audio_bytes)

    if vad is None:
        # Fallback to simple energy detection
        audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
        energy = np.mean(np.abs(audio_data)) / 32768.0
        return energy > 0.015  # Simple threshold

    try:
        # Ensure we have the right amount of data for VAD
        audio_data = np.frombuffer(audio_bytes, dtype=np.int16)

        # If we have too much data, use just the right amount
        if len(audio_data) >= VAD_FRAME_SIZE:
            frame = audio_data[:VAD_FRAME_SIZE].tobytes()
            return vad.is_speech(frame, SAMPLE_RATE)

        # If too little data, accumulate in the VAD buffer and check periodically
        if len(session['vad_buffer']) >= 3:
            # Combine recent chunks to get enough data
            combined = bytearray()
            for chunk in list(session['vad_buffer'])[-3:]:
                combined.extend(chunk)

            # Extract one full frame: VAD_FRAME_SIZE samples is twice that many bytes of int16 PCM
            if len(combined) >= VAD_FRAME_SIZE * 2:
                frame = bytes(combined[:VAD_FRAME_SIZE * 2])
                return vad.is_speech(frame, SAMPLE_RATE)

        return False

    except Exception as e:
        print(f"VAD error: {e}")
        return False

def on_speech_started(session_id):
    """Handle start of user speech"""
    if session_id not in user_sessions:
        return

    session = user_sessions[session_id]

    # Reset audio buffer
    session['audio_buffer'] = bytearray()
    session['is_user_speaking'] = True
    session['is_turn_active'] = True

    # If AI is speaking, we need to interrupt it
    if session['is_ai_speaking']:
        session['should_interrupt_ai'] = True
        socketio.emit('ai_interrupted_by_user', room=session_id)

    # Notify client that we detected speech
    socketio.emit('user_speech_start', room=session_id)

def on_speech_ended(session_id):
    """Handle end of user speech segment"""
    if session_id not in user_sessions:
        return

    session = user_sessions[session_id]

    # Mark as not speaking anymore
    session['is_user_speaking'] = False
    session['speech_frames'] = 0

    # Skip if the buffer is too short (< 2000 samples, roughly 45 ms at the 44.1 kHz client rate)
    # or if a previous utterance is still being processed
    if len(session['audio_buffer']) < 4000 or session['is_processing']:
        session['audio_buffer'] = bytearray()
        return

    # Mark as processing to prevent multiple processes
    session['is_processing'] = True

    # Create a copy of the audio buffer
    audio_copy = session['audio_buffer']
    session['audio_buffer'] = bytearray()

    # Convert audio to the format needed for processing
    try:
        # Convert to float32 between -1 and 1
        audio_np = np.frombuffer(audio_copy, dtype=np.int16).astype(np.float32) / 32768.0
        audio_tensor = torch.from_numpy(audio_np)

        # Resample to Whisper's expected sample rate if necessary
        if CLIENT_SAMPLE_RATE != WHISPER_SAMPLE_RATE:
            audio_tensor = torchaudio.functional.resample(
                audio_tensor,
                orig_freq=CLIENT_SAMPLE_RATE,
                new_freq=WHISPER_SAMPLE_RATE
            )

        # Save as WAV for transcription
        temp_audio_path = f"temp_audio_{session_id}.wav"
        torchaudio.save(
            temp_audio_path,
            audio_tensor.unsqueeze(0),
            WHISPER_SAMPLE_RATE
        )

        # Start transcription and response process in a thread
        threading.Thread(
            target=process_user_utterance,
            args=(session_id, temp_audio_path, audio_tensor),
            daemon=True
        ).start()

        # Notify client that processing has started
        socketio.emit('processing_speech', room=session_id)

    except Exception as e:
        print(f"Error preparing audio: {e}")
        session['is_processing'] = False
        socketio.emit('error', {'message': f'Error processing audio: {str(e)}'}, room=session_id)

def process_user_utterance(session_id, audio_path, audio_tensor):
    """Process user utterance, transcribe and generate response"""
    if session_id not in user_sessions:
        return

    session = user_sessions[session_id]

    try:
        # Transcribe audio
        if whisper_model is not None:
            user_text = transcribe_with_whisper(audio_path)
        else:
            # Fallback to another transcription service
            user_text = transcribe_fallback(audio_path)

        # Clean up temp file
        if os.path.exists(audio_path):
            os.remove(audio_path)

        # Check if we got meaningful text
        if not user_text or len(user_text.strip()) < 2:
            # Use socketio.emit (not emit): this runs in a worker thread, outside the request context
            socketio.emit('no_speech_detected', room=session_id)
            session['is_processing'] = False
            return

        print(f"Transcribed: {user_text}")

        # Create user segment
        user_segment = Segment(
            text=user_text,
            speaker=0,  # User is speaker 0
            audio=audio_tensor
        )
        session['segments'].append(user_segment)

        # Update conversation history
        session['conversation_history'].append({
            'role': 'user',
            'text': user_text
        })

        # Send transcription to client
        socketio.emit('transcription', {'text': user_text}, room=session_id)

        # Generate AI response
        ai_response = generate_ai_response(user_text, session_id)

        # Send text response to client
        socketio.emit('ai_response_text', {'text': ai_response}, room=session_id)

        # Update conversation history
        session['conversation_history'].append({
            'role': 'assistant',
            'text': ai_response
        })

        # Generate voice response if CSM is available
        if csm_generator is not None:
            session['is_ai_speaking'] = True
            session['should_interrupt_ai'] = False

            # Begin streaming audio response
            threading.Thread(
                target=stream_ai_response,
                args=(ai_response, session_id),
                daemon=True
            ).start()

    except Exception as e:
        print(f"Error processing utterance: {e}")
        socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)

    finally:
        # Clear processing flag
        if session_id in user_sessions:
            session['is_processing'] = False

def transcribe_with_whisper(audio_path):
    """Transcribe audio using Faster-Whisper"""
    segments, info = whisper_model.transcribe(audio_path, beam_size=5)

    # Collect all text from segments
    user_text = ""
    for segment in segments:
        user_text += segment.text.strip() + " "

    return user_text.strip()

def transcribe_fallback(audio_path):
    """Fallback transcription using Google's speech recognition"""
    try:
        import speech_recognition as sr
        recognizer = sr.Recognizer()

        with sr.AudioFile(audio_path) as source:
            audio = recognizer.record(source)
            try:
                text = recognizer.recognize_google(audio)
                return text
            except sr.UnknownValueError:
                return ""
            except sr.RequestError:
                return "[Speech recognition service unavailable]"
    except ImportError:
        return "[Speech recognition not available]"

def generate_ai_response(user_text, session_id):
    """Generate text response using available LLM"""
    if session_id not in user_sessions:
        return "I'm sorry, your session has expired."

    session = user_sessions[session_id]

    if llm_model is not None and llm_tokenizer is not None:
        # Format conversation history for the LLM
        prompt = "You are a helpful, friendly voice assistant. Keep your responses brief and conversational.\n\n"

        # Add recent conversation history (last 6 turns maximum)
        for entry in session['conversation_history'][-6:]:
            if entry['role'] == 'user':
                prompt += f"User: {entry['text']}\n"
            else:
                prompt += f"Assistant: {entry['text']}\n"

        # Add current query if not already in history
        if not session['conversation_history'] or session['conversation_history'][-1]['role'] != 'user':
            prompt += f"User: {user_text}\n"

        prompt += "Assistant: "

        try:
            # Generate response
            inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
            output = llm_model.generate(
                inputs.input_ids,
                max_new_tokens=100,  # Keep responses shorter for voice
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )

            response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            return response.strip()

        except Exception as e:
            print(f"Error generating LLM response: {e}")
            return fallback_response(user_text)
    else:
        return fallback_response(user_text)

def fallback_response(user_text):
    """Generate simple fallback responses when LLM is unavailable"""
    user_text_lower = user_text.lower()

    if "hello" in user_text_lower or "hi" in user_text_lower:
        return "Hello! How can I help you today?"

    elif "how are you" in user_text_lower:
        return "I'm doing well, thanks for asking! How about you?"

    elif "thank" in user_text_lower:
        return "You're welcome! Happy to help."

    elif "bye" in user_text_lower or "goodbye" in user_text_lower:
        return "Goodbye! Have a great day!"

    elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]):
        return "That's an interesting question. I wish I could provide a better answer in my current fallback mode."

    else:
        return "I see. Tell me more about that."

def stream_ai_response(text, session_id):
    """Generate and stream audio response in real-time chunks"""
    if session_id not in user_sessions:
        return

    session = user_sessions[session_id]

    try:
        # Signal start of AI speech (socketio.emit: we are in a worker thread, outside the request context)
        socketio.emit('ai_speech_start', room=session_id)

        # Use the last few conversation segments as context (up to 4)
        context_segments = session['segments'][-4:] if len(session['segments']) > 4 else session['segments']

        # Generate audio for bot response
        audio = csm_generator.generate(
            text=text,
            speaker=1,  # Bot is speaker 1
            context=context_segments,
            max_audio_length_ms=10000,  # 10 seconds max
            temperature=0.9,
            topk=50
        )

        # Create and store bot segment
        bot_segment = Segment(
            text=text,
            speaker=1,
            audio=audio
        )

        if session_id in user_sessions:
            session['segments'].append(bot_segment)

        # Stream audio in small chunks for more responsive playback
        chunk_size = AUDIO_CHUNK_SIZE  # Size defined in constants

        for i in range(0, len(audio), chunk_size):
            # Check if we should stop (user interrupted)
            if session_id not in user_sessions or session['should_interrupt_ai']:
                print("AI speech interrupted")
                break

            # Get next chunk
            chunk = audio[i:i + chunk_size]

            # Each chunk is written as a self-contained WAV file and base64-encoded for transport
            audio_bytes = io.BytesIO()
            torchaudio.save(audio_bytes, chunk.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav")
            audio_bytes.seek(0)
            audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')

            # Send chunk to client
            socketio.emit('ai_speech_chunk', {
                'audio': audio_b64,
                'is_last': i + chunk_size >= len(audio)
            }, room=session_id)

            # Small sleep for more natural pacing
            time.sleep(0.06)  # Slight delay for smoother playback

        # Signal end of AI speech
        if session_id in user_sessions:
            session['is_ai_speaking'] = False
            session['is_turn_active'] = False  # End conversation turn
            socketio.emit('ai_speech_end', room=session_id)

    except Exception as e:
        print(f"Error streaming AI response: {e}")
        if session_id in user_sessions:
            session['is_ai_speaking'] = False
            session['is_turn_active'] = False
            socketio.emit('error', {'message': f'Error generating audio: {str(e)}'}, room=session_id)
            socketio.emit('ai_speech_end', room=session_id)

@socketio.on('interrupt_ai')
def handle_interrupt():
    """Handle explicit AI interruption request from client"""
    session_id = request.sid
    if session_id in user_sessions:
        user_sessions[session_id]['should_interrupt_ai'] = True
        emit('ai_interrupted', room=session_id)

@socketio.on('get_config')
def handle_get_config():
    """Send configuration to client"""
    session_id = request.sid
    if session_id in user_sessions:
        emit('config', {
            'client_sample_rate': CLIENT_SAMPLE_RATE,
            'server_sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000,
            'whisper_available': whisper_model is not None,
            'csm_available': csm_generator is not None,
            'ice_servers': ICE_SERVERS
        })

if __name__ == '__main__':
    # Ensure the existing index.html file is in the correct location
    if not os.path.exists('templates'):
        os.makedirs('templates')

    if os.path.exists('index.html') and not os.path.exists('templates/index.html'):
        os.rename('index.html', 'templates/index.html')

    # Load models before starting the server
    print("Starting model loading...")
    load_models()

    # Start the server
    print("Starting Flask SocketIO server...")
    socketio.run(app, host='0.0.0.0', port=5000, debug=False)
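
# Rough run notes (inferred from the imports and the socketio.run() call above,
# not an authoritative setup guide): install eventlet, Flask, Flask-SocketIO,
# faster-whisper, webrtcvad, transformers, torch, torchaudio, numpy and,
# optionally, SpeechRecognition for the fallback transcriber; keep the local
# generator.py (providing load_csm_1b and Segment) next to this file; then run
# this module directly with Python and open http://localhost:5000 in a browser.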