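"""Real-time voice chat server built on Flask-SocketIO.

The browser streams microphone audio to the server over a WebSocket
('audio_stream' events). The server transcribes speech with Faster-Whisper
(falling back to Google speech recognition), generates a text reply with
Llama 3.2 (with a rule-based fallback), synthesizes speech with the CSM model,
and streams the audio back to the client in small chunks ('ai_stream_*'
events).
"""
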
import os
import io
import base64
import time
import torch
import torchaudio
import numpy as np
from flask import Flask, render_template, request
from flask_socketio import SocketIO, emit
from transformers import AutoModelForCausalLM, AutoTokenizer
from collections import deque
import requests
import huggingface_hub
from generator import load_csm_1b, Segment
import threading
import queue
import asyncio
import json

# Configure environment with longer timeouts
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600"  # 10 minutes timeout for downloads
requests.adapters.DEFAULT_TIMEOUT = 60  # Increase default requests timeout

# Create a models directory for caching
os.makedirs("models", exist_ok=True)

app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key'
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')

# Explicitly check for CUDA and print more detailed info
print("\n=== CUDA Information ===")
if torch.cuda.is_available():
    print("CUDA is available")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("CUDA is not available")

# Check for cuDNN
try:
    import ctypes
    ctypes.CDLL("libcudnn_ops_infer.so.8")
    print("cuDNN is available")
except OSError:
    print("cuDNN is not available (libcudnn_ops_infer.so.8 not found)")

# Check for other compute platforms
if torch.backends.mps.is_available():
    print("MPS (Apple Silicon) is available")
else:
    print("MPS is not available")
print("========================\n")

# Check for CUDA availability and handle potential CUDA/cuDNN issues
try:
    if torch.cuda.is_available():
        # Try to initialize CUDA to check if libraries are properly loaded
        _ = torch.zeros(1).cuda()
        device = "cuda"
        whisper_compute_type = "float16"
        print("🟢 CUDA is available and initialized successfully")
    elif torch.backends.mps.is_available():
        device = "mps"
        whisper_compute_type = "float32"
        print("🟢 MPS is available (Apple Silicon)")
    else:
        device = "cpu"
        whisper_compute_type = "int8"
        print("🟡 Using CPU (CUDA/MPS not available)")
except Exception as e:
    print(f"🔴 Error initializing CUDA: {e}")
    print("🔴 Falling back to CPU")
    device = "cpu"
    whisper_compute_type = "int8"

print(f"Using device: {device}")

# Initialize models with proper error handling
whisper_model = None
csm_generator = None
llm_model = None
llm_tokenizer = None

def load_models():
    """Load the Whisper, CSM, and Llama models, degrading gracefully if any fail."""
    global whisper_model, csm_generator, llm_model, llm_tokenizer

    # Initialize Faster-Whisper for transcription
    try:
        print("Loading Whisper model...")
        # Import here to avoid immediate import errors if the package is missing
        from faster_whisper import WhisperModel
        whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper")
        print("Whisper model loaded successfully")
    except Exception as e:
        print(f"Error loading Whisper model: {e}")
        print("Will use backup speech recognition method if available")

    # Initialize CSM model for audio generation
    try:
        print("Loading CSM model...")
        csm_generator = load_csm_1b(device=device)
        print("CSM model loaded successfully")
    except Exception as e:
        print(f"Error loading CSM model: {e}")
        print("Audio generation will not be available")

    # Initialize Llama 3.2 model for response generation
    try:
        print("Loading Llama 3.2 model...")
        llm_model_id = "meta-llama/Llama-3.2-1B"  # Choose appropriate size based on resources
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama")
        # Use the right data type based on device
        dtype = torch.bfloat16 if device != "cpu" else torch.float32
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_id,
            torch_dtype=dtype,
            device_map=device,
            cache_dir="./models/llama",
            low_cpu_mem_usage=True
        )
        print("Llama 3.2 model loaded successfully")
    except Exception as e:
        print(f"Error loading Llama 3.2 model: {e}")
        print("Will use a fallback response generation method")

# Store conversation context
conversation_context = {}  # session_id -> context
active_audio_streams = {}  # session_id -> stream status

@app.route('/')
def index():
    return render_template('index.html')

@socketio.on('connect')
def handle_connect():
    print(f"Client connected: {request.sid}")
    conversation_context[request.sid] = {
        'segments': [],
        'speakers': [0, 1],  # 0 = user, 1 = bot
        'audio_buffer': deque(maxlen=10),  # Store recent audio chunks
        'is_speaking': False,
        'last_activity': time.time(),
        'active_session': True,
        'transcription_buffer': []  # For real-time transcription
    }
    emit('ready', {
        'message': 'Connection established',
        'sample_rate': getattr(csm_generator, 'sample_rate', 24000) if csm_generator else 24000
    })

@socketio.on('disconnect')
def handle_disconnect():
    print(f"Client disconnected: {request.sid}")
    session_id = request.sid

    # Clean up resources
    if session_id in conversation_context:
        conversation_context[session_id]['active_session'] = False
        del conversation_context[session_id]

    if session_id in active_audio_streams:
        active_audio_streams[session_id]['active'] = False
        del active_audio_streams[session_id]

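# The 'audio_stream' handler below expects client payloads shaped roughly like
# (field names taken from the code; the exact client implementation may vary):
#   {'event': 'start'}                                  -- begin an utterance
#   {'event': 'data', 'audio': '<base64 float32 PCM>'}  -- one audio chunk
#   {'event': 'end'}                                    -- utterance finished
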
@socketio.on('audio_stream')
def handle_audio_stream(data):
    """Handle incoming audio stream from client"""
    session_id = request.sid

    if session_id not in conversation_context:
        return

    context = conversation_context[session_id]
    context['last_activity'] = time.time()

    # Process different stream events
    if data.get('event') == 'start':
        # Client is starting to send audio
        context['is_speaking'] = True
        context['audio_buffer'].clear()
        context['transcription_buffer'] = []
        print(f"User {session_id} started streaming audio")

        # If AI was speaking, interrupt it
        if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
            active_audio_streams[session_id]['active'] = False
            emit('ai_stream_interrupt', {}, room=session_id)

    elif data.get('event') == 'data':
        # Audio data received
        if not context['is_speaking']:
            return

        # Decode audio chunk
        try:
            audio_data = base64.b64decode(data.get('audio', ''))
            if not audio_data:
                return

            audio_numpy = np.frombuffer(audio_data, dtype=np.float32)

            # Apply a simple noise gate
            if np.mean(np.abs(audio_numpy)) < 0.01:  # Very quiet
                return

            audio_tensor = torch.tensor(audio_numpy)

            # Add to audio buffer
            context['audio_buffer'].append(audio_tensor)

            # Real-time transcription (periodic)
            if len(context['audio_buffer']) % 3 == 0:  # Process every 3 chunks
                threading.Thread(
                    target=process_realtime_transcription,
                    args=(session_id,),
                    daemon=True
                ).start()
        except Exception as e:
            print(f"Error processing audio chunk: {e}")

    elif data.get('event') == 'end':
        # Client has finished sending audio
        context['is_speaking'] = False

        if len(context['audio_buffer']) > 0:
            # Process the complete utterance
            threading.Thread(
                target=process_complete_utterance,
                args=(session_id,),
                daemon=True
            ).start()

        print(f"User {session_id} stopped streaming audio")

def process_realtime_transcription(session_id):
    """Process incoming audio for real-time transcription"""
    if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
        return

    context = conversation_context[session_id]

    if not context['audio_buffer'] or not context['is_speaking']:
        return

    # Temporary WAV file used for transcription (defined before the try block
    # so the finally clause can always reference it)
    temp_audio_path = f"temp_rt_{session_id}.wav"

    try:
        # Combine current buffer for transcription
        buffer_copy = list(context['audio_buffer'])
        if not buffer_copy:
            return

        full_audio = torch.cat(buffer_copy, dim=0)

        # Save audio to temporary WAV file for transcription
        torchaudio.save(
            temp_audio_path,
            full_audio.unsqueeze(0),
            44100  # Assuming 44.1kHz from client
        )

        # Transcribe with Whisper if available
        if whisper_model is not None:
            segments, _ = whisper_model.transcribe(temp_audio_path, beam_size=5)
            text = " ".join([segment.text for segment in segments])

            if text.strip():
                context['transcription_buffer'].append(text)
                # Send partial transcription to client (socketio.emit is used
                # because this runs in a background thread, outside the request context)
                socketio.emit('partial_transcription', {'text': text}, room=session_id)
    except Exception as e:
        print(f"Error in realtime transcription: {e}")
    finally:
        # Clean up
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)

def process_complete_utterance(session_id):
    """Process completed user utterance, generate response and stream audio back"""
    if session_id not in conversation_context or not conversation_context[session_id]['active_session']:
        return

    context = conversation_context[session_id]

    if not context['audio_buffer']:
        return

    # Combine audio chunks
    full_audio = torch.cat(list(context['audio_buffer']), dim=0)
    context['audio_buffer'].clear()

    # Save audio to temporary WAV file for transcription
    temp_audio_path = f"temp_audio_{session_id}.wav"
    torchaudio.save(
        temp_audio_path,
        full_audio.unsqueeze(0),
        44100  # Assuming 44.1kHz from client
    )

    try:
        # Try using Whisper first if available
        if whisper_model is not None:
            user_text = transcribe_with_whisper(temp_audio_path)
        else:
            # Fallback to Google's speech recognition
            user_text = transcribe_with_google(temp_audio_path)

        if not user_text:
            print("No speech detected.")
            socketio.emit('error', {'message': 'No speech detected. Please try again.'}, room=session_id)
            return

        print(f"Transcribed: {user_text}")

        # Add to conversation segments
        user_segment = Segment(
            text=user_text,
            speaker=0,  # User is speaker 0
            audio=full_audio
        )
        context['segments'].append(user_segment)

        # Generate bot response text
        bot_response = generate_llm_response(user_text, context['segments'])
        print(f"Bot response: {bot_response}")

        # Send transcribed text to client (socketio.emit is used because this
        # runs in a background thread, outside the request context)
        socketio.emit('transcription', {'text': user_text}, room=session_id)

        # Generate and stream audio response if CSM is available
        if csm_generator is not None:
            # Create stream state object
            active_audio_streams[session_id] = {
                'active': True,
                'text': bot_response
            }

            # Send initial response to prepare client
            socketio.emit('ai_stream_start', {
                'text': bot_response
            }, room=session_id)

            # Start audio generation in a separate thread
            threading.Thread(
                target=generate_and_stream_audio_realtime,
                args=(bot_response, context['segments'], session_id),
                daemon=True
            ).start()
        else:
            # Send text-only response if audio generation isn't available
            socketio.emit('text_response', {'text': bot_response}, room=session_id)

            # Add text-only bot response to conversation history
            bot_segment = Segment(
                text=bot_response,
                speaker=1,  # Bot is speaker 1
                audio=torch.zeros(1)  # Placeholder empty audio
            )
            context['segments'].append(bot_segment)

    except Exception as e:
        print(f"Error processing speech: {e}")
        socketio.emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id)
    finally:
        # Cleanup temp file
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)

def transcribe_with_whisper(audio_path):
    """Transcribe audio using Faster-Whisper"""
    segments, info = whisper_model.transcribe(audio_path, beam_size=5)

    # Collect all text from segments
    user_text = ""
    for segment in segments:
        segment_text = segment.text.strip()
        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}")
        user_text += segment_text + " "

    print(f"Transcribed text: {user_text.strip()}")

    return user_text.strip()

def transcribe_with_google(audio_path):
    """Fallback transcription using Google's speech recognition"""
    import speech_recognition as sr
    recognizer = sr.Recognizer()

    with sr.AudioFile(audio_path) as source:
        audio = recognizer.record(source)
        try:
            text = recognizer.recognize_google(audio)
            return text
        except sr.UnknownValueError:
            return ""
        except sr.RequestError:
            # If the Google API fails, return a placeholder so the pipeline can continue.
            # This is a very basic fallback and won't give good results.
            return "[Speech detected but transcription failed]"

def generate_llm_response(user_text, conversation_segments):
    """Generate text response using available model"""
    if llm_model is not None and llm_tokenizer is not None:
        # Format conversation history for the LLM
        conversation_history = ""
        for segment in conversation_segments[-5:]:  # Use last 5 utterances for context
            speaker_name = "User" if segment.speaker == 0 else "Assistant"
            conversation_history += f"{speaker_name}: {segment.text}\n"

        # Add the current user query
        conversation_history += f"User: {user_text}\nAssistant:"

        try:
            # Generate response
            inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device)
            output = llm_model.generate(
                inputs.input_ids,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )

            response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            print(f"Error generating response with LLM: {e}")
            return fallback_response(user_text)
    else:
        return fallback_response(user_text)

def fallback_response(user_text):
    """Generate a simple fallback response when LLM is not available"""
    # Simple rule-based responses
    user_text_lower = user_text.lower()

    if "hello" in user_text_lower or "hi" in user_text_lower:
        return "Hello! I'm a simple fallback assistant. The main language model couldn't be loaded, so I have limited capabilities."

    elif "how are you" in user_text_lower:
        return "I'm functioning within my limited capabilities. How can I assist you today?"

    elif "thank" in user_text_lower:
        return "You're welcome! Let me know if there's anything else I can help with."

    elif "bye" in user_text_lower or "goodbye" in user_text_lower:
        return "Goodbye! Have a great day!"

    elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]):
        return "I'm running in fallback mode and can't answer complex questions. Please try again when the main language model is available."

    else:
        return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available."

def generate_and_stream_audio_realtime(text, conversation_segments, session_id):
    """Generate audio response using CSM and stream it in real-time to client"""
    if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
        return

    try:
        # Use the last few conversation segments as context
        context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments

        # Generate audio for bot response
        audio = csm_generator.generate(
            text=text,
            speaker=1,  # Bot is speaker 1
            context=context_segments,
            max_audio_length_ms=10000,  # 10 seconds max
            temperature=0.9,
            topk=50
        )

        # Store the full audio for conversation history
        bot_segment = Segment(
            text=text,
            speaker=1,  # Bot is speaker 1
            audio=audio
        )
        if session_id in conversation_context and conversation_context[session_id]['active_session']:
            conversation_context[session_id]['segments'].append(bot_segment)

        # Stream audio in small chunks for more responsive playback
        chunk_size = 4800  # 200ms at 24kHz

        for i in range(0, len(audio), chunk_size):
            if session_id not in active_audio_streams or not active_audio_streams[session_id]['active']:
                print("Audio streaming interrupted or session ended")
                break

            chunk = audio[i:i+chunk_size]

            # Convert audio chunk to base64 for streaming
            audio_bytes = io.BytesIO()
            torchaudio.save(audio_bytes, chunk.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav")
            audio_bytes.seek(0)
            audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')

            # Send chunk to client
            socketio.emit('ai_stream_data', {
                'audio': audio_b64,
                'is_last': i + chunk_size >= len(audio)
            }, room=session_id)

            # Simulate real-time speech by adding a small delay
            # Remove this in production for faster response
            time.sleep(0.15)  # Slight delay for more natural timing

        # Signal end of stream
        if session_id in active_audio_streams and active_audio_streams[session_id]['active']:
            socketio.emit('ai_stream_end', {}, room=session_id)
            active_audio_streams[session_id]['active'] = False

    except Exception as e:
        print(f"Error generating or streaming audio: {e}")
        # Send error message to client
        if session_id in conversation_context and conversation_context[session_id]['active_session']:
            socketio.emit('error', {
                'message': f'Error generating audio: {str(e)}'
            }, room=session_id)

        # Signal stream end to unblock client
        socketio.emit('ai_stream_end', {}, room=session_id)
        if session_id in active_audio_streams:
            active_audio_streams[session_id]['active'] = False

if __name__ == '__main__':
    # Ensure the existing index.html file is in the correct location
    if not os.path.exists('templates'):
        os.makedirs('templates')

    if os.path.exists('index.html') and not os.path.exists('templates/index.html'):
        os.rename('index.html', 'templates/index.html')

    # Load models before starting the server
    print("Starting model loading...")
    load_models()

    # Start the server with eventlet for better WebSocket performance
    print("Starting Flask SocketIO server...")
    socketio.run(app, host='0.0.0.0', port=5000, debug=False)
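
# ---------------------------------------------------------------------------
# Usage sketch (commented out, not executed): a minimal test client built on
# the python-socketio package. It assumes that package is installed and that
# the server above is running on http://localhost:5000; event names and
# payload shapes mirror the handlers defined in this file. Note that the pure
# silence sent here would be dropped by the noise gate in handle_audio_stream,
# so this only exercises the protocol, not transcription itself.
#
# import base64
# import numpy as np
# import socketio as sio
#
# client = sio.Client()
#
# @client.on('ready')
# def on_ready(data):
#     print(f"Server ready, sample rate: {data['sample_rate']}")
#     client.emit('audio_stream', {'event': 'start'})
#     chunk = np.zeros(44100, dtype=np.float32).tobytes()  # 1 s of float32 PCM
#     client.emit('audio_stream', {
#         'event': 'data',
#         'audio': base64.b64encode(chunk).decode('utf-8')
#     })
#     client.emit('audio_stream', {'event': 'end'})
#
# @client.on('transcription')
# def on_transcription(data):
#     print(f"Transcribed: {data['text']}")
#
# client.connect('http://localhost:5000')
# client.wait()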