Demo Update 8

2025-03-30 00:46:01 -04:00
parent 8592257cdc
commit e83162c347


@@ -16,6 +16,8 @@ from flask_socketio import SocketIO, emit, disconnect
from generator import load_csm_1b, Segment
from collections import deque
from threading import Lock
from transformers import pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Configure logging
logging.basicConfig(
@@ -84,29 +86,76 @@ logger.info("Initialized simple speech detector")
# Model Loading Functions
def load_speech_models():
"""Load speech generation model"""
# Load speech generation model (Sesame CSM)
try:
logger.info(f"Loading Sesame CSM model on {device}...")
generator = load_csm_1b(device=device)
logger.info("Sesame CSM model loaded successfully")
except Exception as e:
logger.error(f"Error loading Sesame CSM on {device}: {e}")
if device == "cuda":
try:
logger.info("Trying to load Sesame CSM on CPU instead...")
generator = load_csm_1b(device="cpu")
logger.info("Sesame CSM model loaded on CPU successfully")
except Exception as cpu_error:
logger.critical(f"Failed to load speech synthesis model: {cpu_error}")
raise RuntimeError("Failed to load speech synthesis model")
else:
raise RuntimeError("Failed to load speech synthesis model on any device")
"""Load speech generation and recognition models"""
# Load CSM (existing code)
generator = load_csm_1b(device=device)
return generator
# Load Whisper model for speech recognition
try:
logger.info(f"Loading speech recognition model on {device}...")
speech_recognizer = pipeline("automatic-speech-recognition",
model="openai/whisper-small",
device=device)
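        # Note: "openai/whisper-small" is the multilingual checkpoint (~244M parameters);
        # an English-only variant, "openai/whisper-small.en", also exists if only
        # English speech is expected.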
logger.info("Speech recognition model loaded successfully")
except Exception as e:
logger.error(f"Error loading speech recognition model: {e}")
speech_recognizer = None
return generator, speech_recognizer
# Load speech model
generator = load_speech_models()
# Unpack both models
generator, speech_recognizer = load_speech_models()
# Initialize Llama 3.2 model for conversation responses
def load_llm_model():
"""Load Llama 3.2 model for generating text responses"""
try:
logger.info("Loading Llama 3.2 model for conversational responses...")
model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Determine compute device for LLM
llm_device = "cpu" # Default to CPU for LLM
# Use CUDA if available and there's enough VRAM
if device == "cuda" and torch.cuda.is_available():
try:
free_mem = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)
# If we have at least 2GB free, use CUDA for LLM
if free_mem > 2 * 1024 * 1024 * 1024:
llm_device = "cuda"
            except Exception:
                pass
logger.info(f"Using {llm_device} for Llama 3.2 model")
# Load the model with lower precision for efficiency
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16 if llm_device == "cuda" else torch.float32,
device_map=llm_device
)
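        # Note: passing device_map (even a plain "cpu"/"cuda" string) routes loading
        # through the accelerate library, which must be installed; the whole model is
        # then placed on llm_device.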
        # Create a pipeline for easier inference
        llm = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1
        )
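        # Note: max_length=512 here is only a default generation kwarg;
        # generate_response() below passes max_new_tokens=150 at call time, which
        # takes precedence (newer transformers versions may warn when both are set).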
logger.info("Llama 3.2 model loaded successfully")
return llm
except Exception as e:
logger.error(f"Error loading Llama 3.2 model: {e}")
return None
# Load the LLM model
llm = load_llm_model()
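# Note: load_llm_model() returns a transformers text-generation pipeline (or None on
# failure). With default settings the pipeline output's "generated_text" field contains
# the prompt followed by the completion, which is why generate_response() below keeps
# only the text after the final "Assistant:" marker.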
# Set up Flask and Socket.IO
app = Flask(__name__)
@@ -249,26 +298,116 @@ def encode_audio_data(audio_tensor: torch.Tensor) -> str:
return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str:
"""Process speech and return a simple response"""
# In this simplified version, we'll just check if there's sound
# and provide basic responses instead of doing actual speech recognition
"""Process speech with speech recognition"""
if not speech_recognizer:
# Fallback to basic detection if model failed to load
return detect_speech_energy(audio_tensor)
if speech_detector and speech_detector.detect_speech(audio_tensor, generator.sample_rate):
# Generate a response based on audio energy
energy = torch.mean(torch.abs(audio_tensor)).item()
try:
# Save audio to temp file for Whisper
temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav")
torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
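        # Note: given a file path, the ASR pipeline decodes and resamples the audio
        # itself (Whisper expects 16 kHz; file inputs rely on ffmpeg being installed).
        # Passing an in-memory dict such as {"raw": ..., "sampling_rate": generator.sample_rate}
        # would avoid the temp file; the file-based route is kept here for simplicity.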
        if energy > 0.1:  # Louder speech
            return "I heard you speaking clearly. How can I help you today?"
        elif energy > 0.05:  # Moderate speech
            return "I heard you say something. Could you please repeat that?"
        else:  # Soft speech
            return "I detected some speech, but it was quite soft. Could you speak up a bit?"
    else:
        # Perform speech recognition
        result = speech_recognizer(temp_path)
        transcription = result["text"]
        # Clean up temp file
        if os.path.exists(temp_path):
            os.remove(temp_path)
        # Return empty string if no speech detected
        if not transcription or transcription.isspace():
            return "I didn't detect any speech. Could you please try again?"
        return transcription
    except Exception as e:
        logger.error(f"Speech recognition error: {e}")
        return "Sorry, I couldn't understand what you said. Could you try again?"
def detect_speech_energy(audio_tensor: torch.Tensor) -> str:
"""Basic speech detection based on audio energy levels"""
# Calculate audio energy
energy = torch.mean(torch.abs(audio_tensor)).item()
logger.debug(f"Audio energy detected: {energy:.6f}")
# Generate response based on energy level
if energy > 0.1: # Louder speech
return "I heard you speaking clearly. How can I help you today?"
elif energy > 0.05: # Moderate speech
return "I heard you say something. Could you please repeat that?"
elif energy > 0.02: # Soft speech
return "I detected some speech, but it was quite soft. Could you speak up a bit?"
else: # Very soft or no speech
return "I didn't detect any speech. Could you please try again?"
def generate_response(text: str, conversation_history: List[Segment]) -> str:
"""Generate a contextual response based on the transcribed text"""
# Simple response logic - can be replaced with a more sophisticated LLM
"""Generate a contextual response based on the transcribed text using Llama 3.2"""
# If LLM is not available, use simple responses
if llm is None:
return generate_simple_response(text)
try:
# Create a conversational prompt based on history
# Format: recent conversation turns (up to 4) + current user input
history_str = ""
# Add up to 4 recent conversation turns (excluding the current one)
recent_segments = [
seg for seg in conversation_history[-8:]
if seg.text and not seg.text.isspace()
]
for i, segment in enumerate(recent_segments):
speaker_name = "User" if segment.speaker == 0 else "Assistant"
history_str += f"{speaker_name}: {segment.text}\n"
# Construct the prompt for Llama 3.2
prompt = f"""<|system|>
You are Sesame, a helpful, friendly and concise voice assistant.
Keep your responses conversational, natural, and to the point.
Respond to the user's latest message in the context of the conversation.
<|end|>
{history_str}
User: {text}
Assistant:"""
logger.debug(f"LLM Prompt: {prompt}")
# Generate response with the LLM
result = llm(
prompt,
max_new_tokens=150,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1
)
# Extract the generated text
response = result[0]["generated_text"]
# Extract just the Assistant's response (after the prompt)
response = response.split("Assistant:")[-1].strip()
# Clean up and ensure it's not too long for TTS
response = response.split("\n")[0].strip()
if len(response) > 200:
response = response[:197] + "..."
logger.info(f"LLM response: {response}")
return response
except Exception as e:
logger.error(f"Error generating LLM response: {e}")
# Fall back to simple responses
return generate_simple_response(text)
def generate_simple_response(text: str) -> str:
"""Generate a simple rule-based response as fallback"""
responses = {
"hello": "Hello there! How can I help you today?",
"hi": "Hi there! What can I do for you?",
@@ -295,7 +434,7 @@ def generate_response(text: str, conversation_history: List[Segment]) -> str:
    elif len(text) < 10:
        return "Thanks for your message. Could you elaborate a bit more?"
    else:
        return f"I heard you speaking. That's interesting! Can you tell me more about that?"
        return f"I heard you say something about that. Can you tell me more?"
# Flask Routes
@app.route('/')