Demo Fixes 10
@@ -8,9 +8,17 @@ import numpy as np
from flask import Flask, render_template, request
from flask_socketio import SocketIO, emit
from transformers import AutoModelForCausalLM, AutoTokenizer
from generator import load_csm_1b, Segment
from collections import deque
import requests
import huggingface_hub

# Configure environment with longer timeouts
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600"  # 10 minutes timeout for downloads
requests.adapters.DEFAULT_TIMEOUT = 60  # Note: requests does not read this global; per-request timeouts are still needed

# Create a models directory for caching
os.makedirs("models", exist_ok=True)

app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key'
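Between this hunk and the next, the diff skips the lines that create the SocketIO server and pick the compute device: the hunks below use socketio, device, and whisper_compute_type without showing their definitions, and the next hunk header ends in "else:". A minimal sketch of what that elided block presumably looks like, assuming a plain CUDA-or-CPU check; the names come from the hunks, the exact logic and values are assumptions:

socketio = SocketIO(app)  # hypothetical: the actual constructor options are not shown

# Hypothetical reconstruction of the elided device selection
if torch.cuda.is_available():
    device = "cuda"
    whisper_compute_type = "float16"  # assumption: half precision on GPU
else:
    device = "cpu"
    whisper_compute_type = "int8"  # assumption: quantized inference on CPU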
@@ -29,23 +37,50 @@ else:
print(f"Using device: {device}")

# Initialize models with proper error handling
whisper_model = None
csm_generator = None
llm_model = None
llm_tokenizer = None

def load_models():
    global whisper_model, csm_generator, llm_model, llm_tokenizer

    # Initialize Faster-Whisper for transcription
    try:
        print("Loading Whisper model...")
        # Import here to avoid immediate import errors if the package is missing
        from faster_whisper import WhisperModel
        whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper")
        print("Whisper model loaded successfully")
    except Exception as e:
        print(f"Error loading Whisper model: {e}")
        print("Will use backup speech recognition method if available")

    # Initialize CSM model for audio generation
    try:
        print("Loading CSM model...")
        csm_generator = load_csm_1b(device=device)
        print("CSM model loaded successfully")
    except Exception as e:
        print(f"Error loading CSM model: {e}")
        print("Audio generation will not be available")

    # Initialize Llama 3.2 model for response generation
    try:
        print("Loading Llama 3.2 model...")
        llm_model_id = "meta-llama/Llama-3.2-1B"  # Choose an appropriate size based on resources
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama")
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_id,
            torch_dtype=torch.bfloat16,
            device_map=device,
            cache_dir="./models/llama"
        )
        print("Llama 3.2 model loaded successfully")
    except Exception as e:
        print(f"Error loading Llama 3.2 model: {e}")
        print("Will use a fallback response generation method")

# Store conversation context
conversation_context = {}  # session_id -> context
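The handlers below read context['is_speaking'], context['silence_start'], and context['segments'] for each session, but the code that seeds a new entry falls outside the shown hunks. A plausible shape for a fresh session entry, under the assumption that the deque import above is used to buffer incoming audio; the audio_buffer key and helper name are hypothetical, the other keys appear in the hunks below:

def init_session_context(session_id):
    # Hypothetical helper -- the real initialization is not in this diff
    conversation_context[session_id] = {
        'segments': [],  # Segment history shared by the LLM and CSM
        'audio_buffer': deque(),  # assumption: buffered audio chunks from the client
        'is_speaking': False,
        'silence_start': None,
    }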
@@ -128,7 +163,7 @@ def process_user_utterance(session_id):
    context['is_speaking'] = False
    context['silence_start'] = None

    # Save audio to temporary WAV file for transcription
    temp_audio_path = f"temp_audio_{session_id}.wav"
    torchaudio.save(
        temp_audio_path,
@@ -136,25 +171,17 @@ def process_user_utterance(session_id):
        44100  # Assuming 44.1kHz from client
    )

    try:
        # Try using Whisper first if available
        if whisper_model is not None:
            user_text = transcribe_with_whisper(temp_audio_path)
        else:
            # Fallback to Google's speech recognition
            user_text = transcribe_with_google(temp_audio_path)

        if not user_text:
            print("No speech detected.")
            emit('error', {'message': 'No speech detected. Please try again.'}, room=session_id)
            return

        print(f"Transcribed: {user_text}")
@@ -171,6 +198,11 @@ def process_user_utterance(session_id):
        bot_response = generate_llm_response(user_text, context['segments'])
        print(f"Bot response: {bot_response}")

        # Send transcribed text to client
        emit('transcription', {'text': user_text}, room=session_id)

        # Generate and send audio response if CSM is available
        if csm_generator is not None:
            # Convert to audio using CSM
            bot_audio = generate_audio_response(bot_response, context['segments'])
@@ -188,24 +220,64 @@ def process_user_utterance(session_id):
            )
            context['segments'].append(bot_segment)

            # Send audio response to client
            emit('audio_response', {
                'audio': audio_b64,
                'text': bot_response
            }, room=session_id)
        else:
            # Send text-only response if audio generation isn't available
            emit('text_response', {'text': bot_response}, room=session_id)

            # Add text-only bot response to conversation history
            bot_segment = Segment(
                text=bot_response,
                speaker=1,  # Bot is speaker 1
                audio=torch.zeros(1)  # Placeholder empty audio
            )
            context['segments'].append(bot_segment)

    except Exception as e:
        print(f"Error processing speech: {e}")
        emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id)
    finally:
        # Cleanup temp file
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
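The hunk above emits audio_b64 and appends a bot_segment that are both built in context lines the diff does not show (they sit between the last two hunk headers). A minimal sketch of how the waveform could be base64-encoded for the socket payload; the helper name and the WAV container choice are assumptions, not the commit's actual code:

def encode_audio_b64(audio, sample_rate):
    # Hypothetical helper: serialize a 1-D waveform tensor to base64 WAV
    import base64
    import io
    buf = io.BytesIO()
    torchaudio.save(buf, audio.unsqueeze(0).cpu(), sample_rate, format="wav")
    return base64.b64encode(buf.getvalue()).decode("ascii")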
def transcribe_with_whisper(audio_path):
    """Transcribe audio using Faster-Whisper"""
    segments, info = whisper_model.transcribe(audio_path, beam_size=5)

    # Collect all text from segments
    user_text = ""
    for segment in segments:
        segment_text = segment.text.strip()
        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}")
        user_text += segment_text + " "

    return user_text.strip()

def transcribe_with_google(audio_path):
    """Fallback transcription using Google's speech recognition"""
    import speech_recognition as sr
    recognizer = sr.Recognizer()

    with sr.AudioFile(audio_path) as source:
        audio = recognizer.record(source)

    try:
        text = recognizer.recognize_google(audio)
        return text
    except sr.UnknownValueError:
        return ""
    except sr.RequestError:
        # If the Google API fails there is no offline decoder here;
        # return a marker string so the caller still gets a response
        return "[Speech detected but transcription failed]"
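One detail of the Whisper path above: faster-whisper's transcribe() returns a lazy generator, so the actual decoding happens while the for loop in transcribe_with_whisper iterates. If eager behavior is ever needed, for example to time the decode step separately, materializing the segments is a one-liner:

segments, info = whisper_model.transcribe(audio_path, beam_size=5)
segments = list(segments)  # forces decoding to run here instead of inside the loop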
def generate_llm_response(user_text, conversation_segments):
    """Generate text response using available model"""
    if llm_model is not None and llm_tokenizer is not None:
        # Format conversation history for the LLM
        conversation_history = ""
        for segment in conversation_segments[-5:]:  # Use last 5 utterances for context
@@ -215,6 +287,7 @@ def generate_llm_response(user_text, conversation_segments):
        # Add the current user query
        conversation_history += f"User: {user_text}\nAssistant:"

        try:
            # Generate response
            inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device)
            output = llm_model.generate(
@@ -227,9 +300,38 @@ def generate_llm_response(user_text, conversation_segments):
            response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            print(f"Error generating response with LLM: {e}")
            return fallback_response(user_text)
    else:
        return fallback_response(user_text)
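The argument list of the llm_model.generate(...) call falls between the last two hunk headers and is not shown. A plausible reconstruction for completeness; every parameter value here is an assumption:

# Hypothetical reconstruction of the elided generate() arguments
output = llm_model.generate(
    **inputs,
    max_new_tokens=128,  # assumption
    do_sample=True,  # assumption
    temperature=0.7,  # assumption
    pad_token_id=llm_tokenizer.eos_token_id,
)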
def fallback_response(user_text):
    """Generate a simple fallback response when the LLM is not available"""
    # Simple rule-based responses
    user_text_lower = user_text.lower()

    if "hello" in user_text_lower or "hi" in user_text_lower:
        return "Hello! I'm a simple fallback assistant. The main language model couldn't be loaded, so I have limited capabilities."

    elif "how are you" in user_text_lower:
        return "I'm functioning within my limited capabilities. How can I assist you today?"

    elif "thank" in user_text_lower:
        return "You're welcome! Let me know if there's anything else I can help with."

    elif "bye" in user_text_lower or "goodbye" in user_text_lower:
        return "Goodbye! Have a great day!"

    elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]):
        return "I'm running in fallback mode and can't answer complex questions. Please try again when the main language model is available."

    else:
        return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available."
def generate_audio_response(text, conversation_segments):
    """Generate audio response using CSM"""
    try:
        # Use the last few conversation segments as context
        context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments
@@ -244,6 +346,10 @@ def generate_audio_response(text, conversation_segments):
        )

        return audio
    except Exception as e:
        print(f"Error generating audio: {e}")
        # Return silence as fallback
        return torch.zeros(csm_generator.sample_rate * 3)  # 3 seconds of silence
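The csm_generator.generate(...) call itself is elided by the hunk header above. Based on the published CSM-1B usage, it presumably looks like the sketch below; the max_audio_length_ms value is an assumption:

# Hypothetical reconstruction of the elided CSM call
audio = csm_generator.generate(
    text=text,
    speaker=1,  # Bot is speaker 1
    context=context_segments,
    max_audio_length_ms=10_000,  # assumption: cap responses at roughly 10 s
)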
if __name__ == '__main__':
    # Ensure the existing index.html file is in the correct location
@@ -253,4 +359,11 @@ if __name__ == '__main__':
    if os.path.exists('index.html') and not os.path.exists('templates/index.html'):
        os.rename('index.html', 'templates/index.html')

    # Load models before starting the server (this call blocks until done)
    print("Starting model loading...")
    # In a production environment, you could load models in a separate thread
    load_models()

    # Start the server
    print("Starting Flask SocketIO server...")
    socketio.run(app, host='0.0.0.0', port=5000, debug=False)
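As the comment above notes, in production the models could be loaded in a separate thread so the server starts accepting connections while weights download. A minimal sketch of that variant, not part of this commit:

# Hypothetical variant: load models in a daemon thread
import threading

threading.Thread(target=load_models, daemon=True).start()
socketio.run(app, host='0.0.0.0', port=5000, debug=False)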