Merge branch 'main' of https://github.com/GamerBoss101/HooHacks-12
@@ -16,6 +16,8 @@ from flask_socketio import SocketIO, emit, disconnect
 from generator import load_csm_1b, Segment
 from collections import deque
 from threading import Lock
+from transformers import pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 # Configure logging
 logging.basicConfig(
@@ -84,29 +86,76 @@ logger.info("Initialized simple speech detector")
 
 # Model Loading Functions
 def load_speech_models():
-    """Load speech generation model"""
-    # Load speech generation model (Sesame CSM)
-    try:
-        logger.info(f"Loading Sesame CSM model on {device}...")
-        generator = load_csm_1b(device=device)
-        logger.info("Sesame CSM model loaded successfully")
-    except Exception as e:
-        logger.error(f"Error loading Sesame CSM on {device}: {e}")
-        if device == "cuda":
-            try:
-                logger.info("Trying to load Sesame CSM on CPU instead...")
-                generator = load_csm_1b(device="cpu")
-                logger.info("Sesame CSM model loaded on CPU successfully")
-            except Exception as cpu_error:
-                logger.critical(f"Failed to load speech synthesis model: {cpu_error}")
-                raise RuntimeError("Failed to load speech synthesis model")
-        else:
-            raise RuntimeError("Failed to load speech synthesis model on any device")
-
-    return generator
+    """Load speech generation and recognition models"""
+    # Load CSM (existing code)
+    generator = load_csm_1b(device=device)
+
+    # Load Whisper model for speech recognition
+    try:
+        logger.info(f"Loading speech recognition model on {device}...")
+        speech_recognizer = pipeline("automatic-speech-recognition",
+                                     model="openai/whisper-small",
+                                     device=device)
+        logger.info("Speech recognition model loaded successfully")
+    except Exception as e:
+        logger.error(f"Error loading speech recognition model: {e}")
+        speech_recognizer = None
+
+    return generator, speech_recognizer
 
-# Load speech model
-generator = load_speech_models()
+# Unpack both models
+generator, speech_recognizer = load_speech_models()
 
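As a usage note (not part of the diff): a minimal, self-contained sketch of how
a transformers automatic-speech-recognition pipeline like the new
speech_recognizer is typically called. The input path "sample.wav" is
hypothetical.

    from transformers import pipeline

    # Load Whisper-small as an ASR pipeline, as load_speech_models() does.
    asr = pipeline("automatic-speech-recognition",
                   model="openai/whisper-small",
                   device="cpu")  # "cuda" when a GPU is available

    # Transcribe a WAV file; the pipeline returns a dict with a "text" key.
    result = asr("sample.wav")
    print(result["text"])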
+# Initialize Llama 3.2 model for conversation responses
+def load_llm_model():
+    """Load Llama 3.2 model for generating text responses"""
+    try:
+        logger.info("Loading Llama 3.2 model for conversational responses...")
+        model_id = "meta-llama/Llama-3.2-1B-Instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # Determine compute device for LLM
+        llm_device = "cpu"  # Default to CPU for LLM
+
+        # Use CUDA if available and there's enough VRAM
+        if device == "cuda" and torch.cuda.is_available():
+            try:
+                free_mem = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)
+                # If we have at least 2GB free, use CUDA for LLM
+                if free_mem > 2 * 1024 * 1024 * 1024:
+                    llm_device = "cuda"
+            except:
+                pass
+
+        logger.info(f"Using {llm_device} for Llama 3.2 model")
+
+        # Load the model with lower precision for efficiency
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16 if llm_device == "cuda" else torch.float32,
+            device_map=llm_device
+        )
+
+        # Create a pipeline for easier inference
+        llm = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_length=512,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            repetition_penalty=1.1
+        )
+
+        logger.info("Llama 3.2 model loaded successfully")
+        return llm
+    except Exception as e:
+        logger.error(f"Error loading Llama 3.2 model: {e}")
+        return None
+
+# Load the LLM model
+llm = load_llm_model()
+
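A note on the VRAM heuristic above (not part of the diff): total_memory minus
memory_allocated only counts this process's own allocations. A standalone
sketch of the same device-selection logic follows; the helper name
pick_llm_device is ours, not the repo's.

    import torch

    def pick_llm_device(min_free_bytes: int = 2 * 1024 ** 3) -> str:
        """Prefer CUDA when roughly 2 GB of VRAM appear unallocated,
        otherwise fall back to CPU (mirrors the diff's heuristic)."""
        if torch.cuda.is_available():
            try:
                total = torch.cuda.get_device_properties(0).total_memory
                free = total - torch.cuda.memory_allocated(0)
                if free > min_free_bytes:
                    return "cuda"
            except RuntimeError:
                pass  # CUDA visible but unusable; use CPU
        return "cpu"

If device-wide free memory is the real intent, torch.cuda.mem_get_info()
reports it directly in recent PyTorch versions.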
 # Set up Flask and Socket.IO
 app = Flask(__name__)
@@ -249,26 +298,116 @@ def encode_audio_data(audio_tensor: torch.Tensor) -> str:
     return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
 
 def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str:
-    """Process speech and return a simple response"""
-    # In this simplified version, we'll just check if there's sound
-    # and provide basic responses instead of doing actual speech recognition
-    if speech_detector and speech_detector.detect_speech(audio_tensor, generator.sample_rate):
-        # Generate a response based on audio energy
+    """Process speech with speech recognition"""
+    if not speech_recognizer:
+        # Fallback to basic detection if model failed to load
+        return detect_speech_energy(audio_tensor)
+
+    try:
+        # Save audio to temp file for Whisper
+        temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav")
+        torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
+
+        # Perform speech recognition
+        result = speech_recognizer(temp_path)
+        transcription = result["text"]
+
+        # Clean up temp file
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+
+        # Return empty string if no speech detected
+        if not transcription or transcription.isspace():
+            return "I didn't detect any speech. Could you please try again?"
+
+        return transcription
+
+    except Exception as e:
+        logger.error(f"Speech recognition error: {e}")
+        return "Sorry, I couldn't understand what you said. Could you try again?"
+
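An aside on the temp-file round trip (not part of the diff): transformers ASR
pipelines also accept raw samples plus a sampling rate, which avoids disk I/O.
A sketch follows; the helper name transcribe_tensor is ours.

    import numpy as np
    import torch

    def transcribe_tensor(asr, audio_tensor: torch.Tensor, sample_rate: int) -> str:
        """Feed the ASR pipeline in-memory samples instead of a temp WAV file.
        `asr` is a transformers ASR pipeline such as speech_recognizer."""
        samples = audio_tensor.squeeze().cpu().numpy().astype(np.float32)
        result = asr({"raw": samples, "sampling_rate": sample_rate})
        return result["text"].strip()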
+def detect_speech_energy(audio_tensor: torch.Tensor) -> str:
+    """Basic speech detection based on audio energy levels"""
+    # Calculate audio energy
     energy = torch.mean(torch.abs(audio_tensor)).item()
 
+    logger.debug(f"Audio energy detected: {energy:.6f}")
+
+    # Generate response based on energy level
     if energy > 0.1:  # Louder speech
         return "I heard you speaking clearly. How can I help you today?"
     elif energy > 0.05:  # Moderate speech
         return "I heard you say something. Could you please repeat that?"
-    else:  # Soft speech
+    elif energy > 0.02:  # Soft speech
         return "I detected some speech, but it was quite soft. Could you speak up a bit?"
-    else:
+    else:  # Very soft or no speech
         return "I didn't detect any speech. Could you please try again?"
 
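To make the thresholds in detect_speech_energy() concrete (not part of the
diff), a small self-check using synthetic signals:

    import torch

    # Mean absolute amplitude, the same statistic detect_speech_energy() uses.
    t = torch.linspace(0, 1, 16000)
    loud = torch.sin(2 * torch.pi * 440 * t)   # full-scale 440 Hz tone
    quiet = 0.03 * loud                        # heavily attenuated copy
    silence = torch.zeros(16000)

    print(torch.mean(torch.abs(loud)).item())     # ~0.64  -> "clear" branch
    print(torch.mean(torch.abs(quiet)).item())    # ~0.019 -> below the 0.02 floor
    print(torch.mean(torch.abs(silence)).item())  # 0.0    -> "no speech" branch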
 def generate_response(text: str, conversation_history: List[Segment]) -> str:
-    """Generate a contextual response based on the transcribed text"""
-    # Simple response logic - can be replaced with a more sophisticated LLM
+    """Generate a contextual response based on the transcribed text using Llama 3.2"""
+    # If LLM is not available, use simple responses
+    if llm is None:
+        return generate_simple_response(text)
+
+    try:
+        # Create a conversational prompt based on history
+        # Format: recent conversation turns (up to 4) + current user input
+        history_str = ""
+
+        # Add up to 4 recent conversation turns (excluding the current one)
+        recent_segments = [
+            seg for seg in conversation_history[-8:]
+            if seg.text and not seg.text.isspace()
+        ]
+
+        for i, segment in enumerate(recent_segments):
+            speaker_name = "User" if segment.speaker == 0 else "Assistant"
+            history_str += f"{speaker_name}: {segment.text}\n"
+
+        # Construct the prompt for Llama 3.2
+        prompt = f"""<|system|>
+You are Sesame, a helpful, friendly and concise voice assistant.
+Keep your responses conversational, natural, and to the point.
+Respond to the user's latest message in the context of the conversation.
+<|end|>
+
+{history_str}
+User: {text}
+Assistant:"""
+
+        logger.debug(f"LLM Prompt: {prompt}")
+
+        # Generate response with the LLM
+        result = llm(
+            prompt,
+            max_new_tokens=150,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            repetition_penalty=1.1
+        )
+
+        # Extract the generated text
+        response = result[0]["generated_text"]
+
+        # Extract just the Assistant's response (after the prompt)
+        response = response.split("Assistant:")[-1].strip()
+
+        # Clean up and ensure it's not too long for TTS
+        response = response.split("\n")[0].strip()
+        if len(response) > 200:
+            response = response[:197] + "..."
+
+        logger.info(f"LLM response: {response}")
+        return response
+
+    except Exception as e:
+        logger.error(f"Error generating LLM response: {e}")
+        # Fall back to simple responses
+        return generate_simple_response(text)
+
+def generate_simple_response(text: str) -> str:
+    """Generate a simple rule-based response as fallback"""
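One caveat on the hand-built prompt above (not part of the diff): <|system|>
and <|end|> are not Llama 3.2 special tokens, so the model reads them as plain
text. A sketch of the chat-template alternative, reusing the diff's tokenizer
and llm; the sample user turn is hypothetical.

    # Build the prompt with the model's own chat template and let the
    # pipeline return only newly generated text.
    messages = [
        {"role": "system", "content": "You are Sesame, a helpful, friendly and concise voice assistant."},
        {"role": "user", "content": "What's the weather like today?"},  # hypothetical turn
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False,
                                           add_generation_prompt=True)
    result = llm(prompt, max_new_tokens=150, return_full_text=False)
    reply = result[0]["generated_text"].strip()  # no "Assistant:" split needed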
     responses = {
         "hello": "Hello there! How can I help you today?",
         "hi": "Hi there! What can I do for you?",
@@ -295,7 +434,7 @@ def generate_response(text: str, conversation_history: List[Segment]) -> str:
     elif len(text) < 10:
         return "Thanks for your message. Could you elaborate a bit more?"
     else:
-        return f"I heard you speaking. That's interesting! Can you tell me more about that?"
+        return f"I heard you say something about that. Can you tell me more?"
 
 # Flask Routes
 @app.route('/')
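The hunk above elides the middle of generate_simple_response() between the
responses dict and the length-based fallbacks. A plausible sketch of the usual
keyword-matching pattern; this is our reconstruction, not the repo's code.

    def lookup_response(text: str, responses: dict) -> str:
        """Hypothetical keyword lookup against the responses dict."""
        lowered = text.lower()
        for keyword, reply in responses.items():
            if keyword in lowered:
                return reply
        if len(text) < 10:
            return "Thanks for your message. Could you elaborate a bit more?"
        return "I heard you say something about that. Can you tell me more?"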