Frontend Fixed
392
Backend/api/socket_handlers.py
Normal file
@@ -0,0 +1,392 @@
import os
import io
import base64
import time
import threading
import queue
import tempfile
import gc
import logging
import traceback
from typing import Dict, List, Optional

import torch
import torchaudio
import numpy as np
from flask import request
from flask_socketio import emit

# Import conversation model
from generator import Segment

logger = logging.getLogger(__name__)

# Conversation data structure
class Conversation:
    def __init__(self, session_id):
        self.session_id = session_id
        self.segments: List[Segment] = []
        self.current_speaker = 0
        self.ai_speaker_id = 1  # Default AI speaker ID
        self.last_activity = time.time()
        self.is_processing = False

    def add_segment(self, text, speaker, audio):
        segment = Segment(text=text, speaker=speaker, audio=audio)
        self.segments.append(segment)
        self.last_activity = time.time()
        return segment

    def get_context(self, max_segments=10):
        """Return the most recent segments for context"""
        return self.segments[-max_segments:] if self.segments else []

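# Note: `models` is a shared container defined elsewhere in the backend; the handlers
# below assume it exposes generator, whisperx_model, whisperx_align_model,
# whisperx_align_metadata, last_language, llm, and tokenizer attributes.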
def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE):
    """Register Socket.IO event handlers"""

    @socketio.on('connect')
    def handle_connect(auth=None):
        """Handle client connection"""
        session_id = request.sid
        logger.info(f"Client connected: {session_id}")

        # Initialize conversation data
        if session_id not in active_conversations:
            active_conversations[session_id] = Conversation(session_id)
            user_queues[session_id] = queue.Queue()
            processing_threads[session_id] = threading.Thread(
                target=process_audio_queue,
                args=(session_id, user_queues[session_id], app, socketio, models, active_conversations, user_queues, DEVICE),
                daemon=True
            )
            processing_threads[session_id].start()

        emit('connection_status', {'status': 'connected'})

    @socketio.on('disconnect')
    def handle_disconnect(reason=None):
        """Handle client disconnection"""
        session_id = request.sid
        logger.info(f"Client disconnected: {session_id}. Reason: {reason}")

        # Cleanup
        if session_id in active_conversations:
            # Mark for deletion rather than immediately removing,
            # as the processing thread might still be accessing it
            active_conversations[session_id].is_processing = False
            user_queues[session_id].put(None)  # Signal thread to terminate

    @socketio.on('audio_data')
    def handle_audio_data(data):
        """Handle incoming audio data"""
        session_id = request.sid
        logger.info(f"Received audio data from {session_id}")

        # Check if the models are loaded
        if models.generator is None or models.whisperx_model is None or models.llm is None:
            emit('error', {'message': 'Models still loading, please wait'})
            return

        # Check if we're already processing for this session
        if session_id in active_conversations and active_conversations[session_id].is_processing:
            emit('error', {'message': 'Still processing previous audio, please wait'})
            return

        # Add to processing queue
        if session_id in user_queues:
            user_queues[session_id].put(data)
        else:
            emit('error', {'message': 'Session not initialized, please refresh the page'})

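# One worker thread per connected session: audio chunks are pulled off that
# session's queue and handled strictly in arrival order.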
def process_audio_queue(session_id, q, app, socketio, models, active_conversations, user_queues, DEVICE):
    """Background thread to process audio chunks for a session"""
    logger.info(f"Started processing thread for session: {session_id}")

    try:
        while session_id in active_conversations:
            try:
                # Get the next audio chunk with a timeout
                data = q.get(timeout=120)
                if data is None:  # Termination signal
                    break

                # Process the audio and generate a response
                process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE)

            except queue.Empty:
                # Timeout, check if session is still valid
                continue
            except Exception as e:
                logger.error(f"Error processing audio for {session_id}: {str(e)}")
                # Create an app context for the socket emit
                with app.app_context():
                    socketio.emit('error', {'message': str(e)}, room=session_id)
    finally:
        logger.info(f"Ending processing thread for session: {session_id}")
        # Clean up when thread is done
        with app.app_context():
            if session_id in active_conversations:
                del active_conversations[session_id]
            if session_id in user_queues:
                del user_queues[session_id]

def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE):
    """Process audio data and generate a response using WhisperX"""
    if models.generator is None or models.whisperx_model is None or models.llm is None:
        logger.warning("Models not yet loaded!")
        with app.app_context():
            socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
        return

    logger.info(f"Processing audio for session {session_id}")
    conversation = active_conversations[session_id]

    try:
        # Set processing flag
        conversation.is_processing = True

        # Process base64 audio data
        audio_data = data['audio']
        speaker_id = data['speaker']
        logger.info(f"Received audio from speaker {speaker_id}")

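        # The client is assumed to send a data URL ("data:audio/...;base64,<payload>"),
        # so everything before the first comma is stripped before decoding.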
        # Convert from base64 to WAV
        try:
            audio_bytes = base64.b64decode(audio_data.split(',')[1])
            logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
        except Exception as e:
            logger.error(f"Error decoding base64 audio: {str(e)}")
            raise

        # Save to temporary file for processing
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name

        try:
            # Notify client that transcription is starting
            with app.app_context():
                socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)

            # Load audio using WhisperX
            import whisperx
            audio = whisperx.load_audio(temp_path)

            # Check audio length and add a warning for short clips
            audio_length = len(audio) / 16000  # assuming 16kHz sample rate
            if audio_length < 1.0:
                logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")

            # Transcribe using WhisperX
            batch_size = 16  # adjust based on your GPU memory
            logger.info("Running WhisperX transcription...")

            # Suppress the warning about audio being shorter than 30s
            import warnings
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message="audio is shorter than 30s")
                result = models.whisperx_model.transcribe(audio, batch_size=batch_size)

            # Get the detected language
            language_code = result["language"]
            logger.info(f"Detected language: {language_code}")

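            # The alignment model is cached per language: it is only (re)loaded when the
            # detected language changes, and the previous one is freed from GPU memory first.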
            # Check if alignment model needs to be loaded or updated
            if models.whisperx_align_model is None or language_code != models.last_language:
                # Clean up old models if they exist
                if models.whisperx_align_model is not None:
                    del models.whisperx_align_model
                    del models.whisperx_align_metadata
                    if DEVICE == "cuda":
                        gc.collect()
                        torch.cuda.empty_cache()

                # Load new alignment model for the detected language
                logger.info(f"Loading alignment model for language: {language_code}")
                models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
                    language_code=language_code, device=DEVICE
                )
                models.last_language = language_code

            # Align the transcript to get word-level timestamps
            if result["segments"] and len(result["segments"]) > 0:
                logger.info("Aligning transcript...")
                result = whisperx.align(
                    result["segments"],
                    models.whisperx_align_model,
                    models.whisperx_align_metadata,
                    audio,
                    DEVICE,
                    return_char_alignments=False
                )

                # Process the segments for better output
                for segment in result["segments"]:
                    # Round timestamps for better display
                    segment["start"] = round(segment["start"], 2)
                    segment["end"] = round(segment["end"], 2)
                    # Add a confidence score if not present
                    if "confidence" not in segment:
                        segment["confidence"] = 1.0  # Default confidence

            # Extract the full text from all segments
            user_text = ' '.join([segment['text'] for segment in result['segments']])

            # If no text was recognized, don't process further
            if not user_text or len(user_text.strip()) == 0:
                with app.app_context():
                    socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
                return

            logger.info(f"Transcription: {user_text}")

            # Load audio for CSM input
            waveform, sample_rate = torchaudio.load(temp_path)

            # Normalize to mono if needed
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Resample to the CSM sample rate if needed
            if sample_rate != models.generator.sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    orig_freq=sample_rate,
                    new_freq=models.generator.sample_rate
                )

            # Add the user's message to conversation history
            user_segment = conversation.add_segment(
                text=user_text,
                speaker=speaker_id,
                audio=waveform.squeeze()
            )

            # Send transcription to client with detailed segments
            with app.app_context():
                socketio.emit('transcription', {
                    'text': user_text,
                    'speaker': speaker_id,
                    'segments': result['segments']  # Include the detailed segments with timestamps
                }, room=session_id)

            # Generate AI response using Llama
            with app.app_context():
                socketio.emit('processing_status', {'status': 'generating'}, room=session_id)

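            # Prompt convention used here: speaker 0 is the human ("User") and any other
            # speaker id (the AI uses speaker 1) is rendered as "Assistant".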
            # Create prompt from conversation history
            conversation_history = ""
            for segment in conversation.segments[-5:]:  # Last 5 segments for context
                role = "User" if segment.speaker == 0 else "Assistant"
                conversation_history += f"{role}: {segment.text}\n"

            # Add final prompt
            prompt = f"{conversation_history}Assistant: "

            # Generate response with Llama
            try:
                # Ensure pad token is set
                if models.tokenizer.pad_token is None:
                    models.tokenizer.pad_token = models.tokenizer.eos_token

                input_tokens = models.tokenizer(
                    prompt,
                    return_tensors="pt",
                    padding=True,
                    return_attention_mask=True
                )
                input_ids = input_tokens.input_ids.to(DEVICE)
                attention_mask = input_tokens.attention_mask.to(DEVICE)

                with torch.no_grad():
                    generated_ids = models.llm.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=100,
                        temperature=0.7,
                        top_p=0.9,
                        do_sample=True,
                        pad_token_id=models.tokenizer.eos_token_id
                    )

                # Decode the response
                response_text = models.tokenizer.decode(
                    generated_ids[0][input_ids.shape[1]:],
                    skip_special_tokens=True
                ).strip()
            except Exception as e:
                logger.error(f"Error generating response: {str(e)}")
                logger.error(traceback.format_exc())
                response_text = "I'm sorry, I encountered an error while processing your request."

            # Synthesize speech
            with app.app_context():
                socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)

                # Start sending the audio response
                socketio.emit('audio_response_start', {
                    'text': response_text,
                    'total_chunks': 1,
                    'chunk_index': 0
                }, room=session_id)

            # Define AI speaker ID
            ai_speaker_id = conversation.ai_speaker_id

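            # The CSM generator is conditioned on the recent Segments (text + audio) from the
            # conversation; max_audio_length_ms caps the reply at roughly 10 seconds.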
            # Generate audio
            audio_tensor = models.generator.generate(
                text=response_text,
                speaker=ai_speaker_id,
                context=conversation.get_context(),
                max_audio_length_ms=10_000,
                temperature=0.9
            )

            # Add AI response to conversation history
            ai_segment = conversation.add_segment(
                text=response_text,
                speaker=ai_speaker_id,
                audio=audio_tensor
            )

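            # Encode the generated waveform as WAV in memory and return it to the client
            # as a base64 data URL, mirroring the data-URL format used for incoming audio.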
            # Convert audio to WAV format
            with io.BytesIO() as wav_io:
                torchaudio.save(
                    wav_io,
                    audio_tensor.unsqueeze(0).cpu(),
                    models.generator.sample_rate,
                    format="wav"
                )
                wav_io.seek(0)
                wav_data = wav_io.read()

            # Convert WAV data to base64
            audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"

            # Send audio chunk to client
            with app.app_context():
                socketio.emit('audio_response_chunk', {
                    'chunk': audio_base64,
                    'chunk_index': 0,
                    'total_chunks': 1,
                    'is_last': True
                }, room=session_id)

                # Signal completion
                socketio.emit('audio_response_complete', {
                    'text': response_text
                }, room=session_id)

        finally:
            # Clean up temp file
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        logger.error(traceback.format_exc())
        with app.app_context():
            socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
    finally:
        # Reset processing flag
        conversation.is_processing = False