backend restart
@@ -1,136 +0,0 @@
import os
import logging
import threading
from dataclasses import dataclass
from flask import Flask
from flask_socketio import SocketIO
from flask_cors import CORS

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure device
import torch
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

logger.info(f"Using device: {DEVICE}")

# Initialize Flask app
app = Flask(__name__, static_folder='../', static_url_path='')
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)

# Global variables for conversation state
active_conversations = {}
user_queues = {}
processing_threads = {}

# Model storage
@dataclass
class AppModels:
    generator = None
    tokenizer = None
    llm = None
    whisperx_model = None
    whisperx_align_model = None
    whisperx_align_metadata = None
    last_language = None

models = AppModels()

def load_models():
    """Load all required models"""
    from generator import load_csm_1b
    import whisperx
    import gc
    from transformers import AutoModelForCausalLM, AutoTokenizer
    global models

    socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})

    # CSM 1B loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
        models.generator = load_csm_1b(device=DEVICE)
        logger.info("CSM 1B model loaded successfully")
        socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
        socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})

    # WhisperX loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
        # Use WhisperX for better transcription with timestamps
        # Use compute_type based on device
        compute_type = "float16" if DEVICE == "cuda" else "float32"

        # Load the WhisperX model (smaller model for faster processing)
        models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)

        logger.info("WhisperX model loaded successfully")
        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})

    # Llama loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
        models.llm = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            device_map=DEVICE,
            torch_dtype=torch.bfloat16
        )
        models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

        # Configure all special tokens
        models.tokenizer.pad_token = models.tokenizer.eos_token
        models.tokenizer.padding_side = "left"  # For causal language modeling

        # Inform the model about the pad token
        if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
            models.llm.config.pad_token_id = models.tokenizer.pad_token_id

        logger.info("Llama 3.2 model loaded successfully")
        socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
    except Exception as e:
        logger.error(f"Error loading Llama 3.2 model: {str(e)}")
        socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})

# Load models in a background thread
threading.Thread(target=load_models, daemon=True).start()

# Import routes and socket handlers
from api.routes import register_routes
from api.socket_handlers import register_handlers

# Register routes and socket handlers
register_routes(app)
register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE)

# Run server if executed directly
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
    logger.info(f"Starting server on port {port} (debug={debug_mode})")
    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
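The module above reports loading progress through 'model_status' events instead of blocking the HTTP server. A minimal client sketch for watching those events, assuming the python-socketio client package and the default port 5000:

# Minimal sketch of a client watching the model-loading progress events
# emitted above. Assumes the python-socketio client package; the URL and
# port are illustrative.
import socketio

sio = socketio.Client()

@sio.on('model_status')
def on_model_status(data):
    # e.g. {'model': 'overall', 'status': 'loading', 'progress': 33}
    print(f"{data.get('model')}: {data.get('status')} {data.get('progress', '')}")

sio.connect('http://localhost:5000')
sio.wait()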
@@ -1,74 +0,0 @@
import os
import torch
import psutil
from flask import send_from_directory, jsonify, request

def register_routes(app):
    """Register HTTP routes for the application"""

    @app.route('/')
    def index():
        """Serve the main application page"""
        return send_from_directory(app.static_folder, 'index.html')

    @app.route('/voice-chat.js')
    def serve_js():
        """Serve the JavaScript file"""
        return send_from_directory(app.static_folder, 'voice-chat.js')

    @app.route('/api/status')
    def system_status():
        """Return the system status"""
        # Import here to avoid circular imports
        from api.app import models, DEVICE

        return jsonify({
            "status": "ok",
            "cuda_available": torch.cuda.is_available(),
            "device": DEVICE,
            "models": {
                "generator": models.generator is not None,
                "asr": models.whisperx_model is not None,
                "llm": models.llm is not None
            },
            "versions": {
                "transformers": "4.49.0",  # Replace with actual version
                "torch": torch.__version__
            }
        })

    @app.route('/api/system_resources')
    def system_resources():
        """Return system resource usage"""
        # Import here to avoid circular imports
        from api.app import active_conversations, DEVICE

        # Get CPU usage
        cpu_percent = psutil.cpu_percent(interval=0.1)

        # Get memory usage
        memory = psutil.virtual_memory()
        memory_used_gb = memory.used / (1024 ** 3)
        memory_total_gb = memory.total / (1024 ** 3)
        memory_percent = memory.percent

        # Get GPU memory if available
        gpu_memory = {}
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                gpu_memory[f"gpu_{i}"] = {
                    "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
                    "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
                    "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
                }

        return jsonify({
            "cpu_percent": cpu_percent,
            "memory": {
                "used_gb": memory_used_gb,
                "total_gb": memory_total_gb,
                "percent": memory_percent
            },
            "gpu_memory": gpu_memory,
            "active_sessions": len(active_conversations)
        })
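Both routes return plain JSON, so they can be smoke-tested without the frontend. A short sketch using the requests library; host and port are assumptions (the server defaults to 5000):

# Quick check of the two JSON endpoints registered above.
# The base URL is an assumption for illustration.
import requests

base = "http://localhost:5000"

status = requests.get(f"{base}/api/status").json()
print(status["device"], status["models"])  # e.g. "cuda" {'generator': True, ...}

resources = requests.get(f"{base}/api/system_resources").json()
print(f"CPU {resources['cpu_percent']}%, "
      f"RAM {resources['memory']['used_gb']:.1f}/{resources['memory']['total_gb']:.1f} GB, "
      f"{resources['active_sessions']} active session(s)")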
@@ -1,393 +0,0 @@
import os
import io
import base64
import time
import threading
import queue
import tempfile
import gc
import logging
import traceback
from typing import Dict, List, Optional

import torch
import torchaudio
import numpy as np
from flask import request
from flask_socketio import emit

# Import conversation model
from generator import Segment

logger = logging.getLogger(__name__)

# Conversation data structure
class Conversation:
    def __init__(self, session_id):
        self.session_id = session_id
        self.segments: List[Segment] = []
        self.current_speaker = 0
        self.ai_speaker_id = 1  # Default AI speaker ID
        self.last_activity = time.time()
        self.is_processing = False

    def add_segment(self, text, speaker, audio):
        segment = Segment(text=text, speaker=speaker, audio=audio)
        self.segments.append(segment)
        self.last_activity = time.time()
        return segment

    def get_context(self, max_segments=10):
        """Return the most recent segments for context"""
        return self.segments[-max_segments:] if self.segments else []

def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE):
    """Register Socket.IO event handlers"""
    # No need for global references, just use the parameters directly

    @socketio.on('connect')
    def handle_connect(auth=None):
        """Handle client connection"""
        session_id = request.sid
        logger.info(f"Client connected: {session_id}")

        # Initialize conversation data
        if session_id not in active_conversations:
            active_conversations[session_id] = Conversation(session_id)
            user_queues[session_id] = queue.Queue()
            processing_threads[session_id] = threading.Thread(
                target=process_audio_queue,
                args=(session_id, user_queues[session_id], app, socketio, models, active_conversations, DEVICE),
                daemon=True
            )
            processing_threads[session_id].start()

        emit('connection_status', {'status': 'connected'})

    @socketio.on('disconnect')
    def handle_disconnect(reason=None):
        """Handle client disconnection"""
        session_id = request.sid
        logger.info(f"Client disconnected: {session_id}. Reason: {reason}")

        # Cleanup
        if session_id in active_conversations:
            # Mark for deletion rather than immediately removing
            # as the processing thread might still be accessing it
            active_conversations[session_id].is_processing = False
            user_queues[session_id].put(None)  # Signal thread to terminate

    @socketio.on('audio_data')
    def handle_audio_data(data):
        """Handle incoming audio data"""
        session_id = request.sid
        logger.info(f"Received audio data from {session_id}")

        # Check if the models are loaded
        if models.generator is None or models.whisperx_model is None or models.llm is None:
            emit('error', {'message': 'Models still loading, please wait'})
            return

        # Check if we're already processing for this session
        if session_id in active_conversations and active_conversations[session_id].is_processing:
            emit('error', {'message': 'Still processing previous audio, please wait'})
            return

        # Add to processing queue
        if session_id in user_queues:
            user_queues[session_id].put(data)
        else:
            emit('error', {'message': 'Session not initialized, please refresh the page'})

def process_audio_queue(session_id, q, app, socketio, models, active_conversations, DEVICE):
    """Background thread to process audio chunks for a session"""
    logger.info(f"Started processing thread for session: {session_id}")

    try:
        while session_id in active_conversations:
            try:
                # Get the next audio chunk with a timeout
                data = q.get(timeout=120)
                if data is None:  # Termination signal
                    break

                # Process the audio and generate a response
                process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE)

            except queue.Empty:
                # Timeout, check if session is still valid
                continue
            except Exception as e:
                logger.error(f"Error processing audio for {session_id}: {str(e)}")
                # Create an app context for the socket emit
                with app.app_context():
                    socketio.emit('error', {'message': str(e)}, room=session_id)
    finally:
        logger.info(f"Ending processing thread for session: {session_id}")
        # Clean up when thread is done
        with app.app_context():
            if session_id in active_conversations:
                del active_conversations[session_id]
            if session_id in user_queues:  # Use the passed-in reference
                del user_queues[session_id]

def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE):
    """Process audio data and generate a response using WhisperX"""
    if models.generator is None or models.whisperx_model is None or models.llm is None:
        logger.warning("Models not yet loaded!")
        with app.app_context():
            socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
        return

    logger.info(f"Processing audio for session {session_id}")
    conversation = active_conversations[session_id]

    try:
        # Set processing flag
        conversation.is_processing = True

        # Process base64 audio data
        audio_data = data['audio']
        speaker_id = data['speaker']
        logger.info(f"Received audio from speaker {speaker_id}")

        # Convert from base64 to WAV
        try:
            audio_bytes = base64.b64decode(audio_data.split(',')[1])
            logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
        except Exception as e:
            logger.error(f"Error decoding base64 audio: {str(e)}")
            raise

        # Save to temporary file for processing
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name

        try:
            # Notify client that transcription is starting
            with app.app_context():
                socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)

            # Load audio using WhisperX
            import whisperx
            audio = whisperx.load_audio(temp_path)

            # Check audio length and add a warning for short clips
            audio_length = len(audio) / 16000  # assuming 16kHz sample rate
            if audio_length < 1.0:
                logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")

            # Transcribe using WhisperX
            batch_size = 16  # adjust based on your GPU memory
            logger.info("Running WhisperX transcription...")

            # Handle the warning about audio being shorter than 30s by suppressing it
            import warnings
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message="audio is shorter than 30s")
                result = models.whisperx_model.transcribe(audio, batch_size=batch_size)

            # Get the detected language
            language_code = result["language"]
            logger.info(f"Detected language: {language_code}")

            # Check if alignment model needs to be loaded or updated
            if models.whisperx_align_model is None or language_code != models.last_language:
                # Clean up old models if they exist
                if models.whisperx_align_model is not None:
                    del models.whisperx_align_model
                    del models.whisperx_align_metadata
                    if DEVICE == "cuda":
                        gc.collect()
                        torch.cuda.empty_cache()

                # Load new alignment model for the detected language
                logger.info(f"Loading alignment model for language: {language_code}")
                models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
                    language_code=language_code, device=DEVICE
                )
                models.last_language = language_code

            # Align the transcript to get word-level timestamps
            if result["segments"] and len(result["segments"]) > 0:
                logger.info("Aligning transcript...")
                result = whisperx.align(
                    result["segments"],
                    models.whisperx_align_model,
                    models.whisperx_align_metadata,
                    audio,
                    DEVICE,
                    return_char_alignments=False
                )

                # Process the segments for better output
                for segment in result["segments"]:
                    # Round timestamps for better display
                    segment["start"] = round(segment["start"], 2)
                    segment["end"] = round(segment["end"], 2)
                    # Add a confidence score if not present
                    if "confidence" not in segment:
                        segment["confidence"] = 1.0  # Default confidence

            # Extract the full text from all segments
            user_text = ' '.join([segment['text'] for segment in result['segments']])

            # If no text was recognized, don't process further
            if not user_text or len(user_text.strip()) == 0:
                with app.app_context():
                    socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
                return

            logger.info(f"Transcription: {user_text}")

            # Load audio for CSM input
            waveform, sample_rate = torchaudio.load(temp_path)

            # Normalize to mono if needed
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Resample to the CSM sample rate if needed
            if sample_rate != models.generator.sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    orig_freq=sample_rate,
                    new_freq=models.generator.sample_rate
                )

            # Add the user's message to conversation history
            user_segment = conversation.add_segment(
                text=user_text,
                speaker=speaker_id,
                audio=waveform.squeeze()
            )

            # Send transcription to client with detailed segments
            with app.app_context():
                socketio.emit('transcription', {
                    'text': user_text,
                    'speaker': speaker_id,
                    'segments': result['segments']  # Include the detailed segments with timestamps
                }, room=session_id)

            # Generate AI response using Llama
            with app.app_context():
                socketio.emit('processing_status', {'status': 'generating'}, room=session_id)

            # Create prompt from conversation history
            conversation_history = ""
            for segment in conversation.segments[-5:]:  # Last 5 segments for context
                role = "User" if segment.speaker == 0 else "Assistant"
                conversation_history += f"{role}: {segment.text}\n"

            # Add final prompt
            prompt = f"{conversation_history}Assistant: "

            # Generate response with Llama
            try:
                # Ensure pad token is set
                if models.tokenizer.pad_token is None:
                    models.tokenizer.pad_token = models.tokenizer.eos_token

                input_tokens = models.tokenizer(
                    prompt,
                    return_tensors="pt",
                    padding=True,
                    return_attention_mask=True
                )
                input_ids = input_tokens.input_ids.to(DEVICE)
                attention_mask = input_tokens.attention_mask.to(DEVICE)

                with torch.no_grad():
                    generated_ids = models.llm.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=100,
                        temperature=0.7,
                        top_p=0.9,
                        do_sample=True,
                        pad_token_id=models.tokenizer.eos_token_id
                    )

                # Decode the response
                response_text = models.tokenizer.decode(
                    generated_ids[0][input_ids.shape[1]:],
                    skip_special_tokens=True
                ).strip()
            except Exception as e:
                logger.error(f"Error generating response: {str(e)}")
                logger.error(traceback.format_exc())
                response_text = "I'm sorry, I encountered an error while processing your request."

            # Synthesize speech
            with app.app_context():
                socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)

                # Start sending the audio response
                socketio.emit('audio_response_start', {
                    'text': response_text,
                    'total_chunks': 1,
                    'chunk_index': 0
                }, room=session_id)

            # Define AI speaker ID
            ai_speaker_id = conversation.ai_speaker_id

            # Generate audio
            audio_tensor = models.generator.generate(
                text=response_text,
                speaker=ai_speaker_id,
                context=conversation.get_context(),
                max_audio_length_ms=10_000,
                temperature=0.9
            )

            # Add AI response to conversation history
            ai_segment = conversation.add_segment(
                text=response_text,
                speaker=ai_speaker_id,
                audio=audio_tensor
            )

            # Convert audio to WAV format
            with io.BytesIO() as wav_io:
                torchaudio.save(
                    wav_io,
                    audio_tensor.unsqueeze(0).cpu(),
                    models.generator.sample_rate,
                    format="wav"
                )
                wav_io.seek(0)
                wav_data = wav_io.read()

            # Convert WAV data to base64
            audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"

            # Send audio chunk to client
            with app.app_context():
                socketio.emit('audio_response_chunk', {
                    'chunk': audio_base64,
                    'chunk_index': 0,
                    'total_chunks': 1,
                    'is_last': True
                }, room=session_id)

                # Signal completion
                socketio.emit('audio_response_complete', {
                    'text': response_text
                }, room=session_id)

        finally:
            # Clean up temp file
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        logger.error(traceback.format_exc())
        with app.app_context():
            socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
    finally:
        # Reset processing flag
        conversation.is_processing = False
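handle_audio_data() expects one message per utterance: a data-URL encoded WAV plus a numeric speaker id. A sketch of building that payload on the client side; the file name is illustrative:

# Sketch of the payload handle_audio_data() expects. The WAV file path is
# an assumption for illustration.
import base64

with open("utterance.wav", "rb") as f:
    wav_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "audio": f"data:audio/wav;base64,{wav_b64}",  # split(',')[1] recovers the raw base64
    "speaker": 0,                                 # user speaker id
}
# A connected Socket.IO client would then emit:
# sio.emit('audio_data', payload)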
Backend/app.py (new file, 229 lines)
@@ -0,0 +1,229 @@
import os
import io
import base64
import time
import torch
import torchaudio
import numpy as np
from flask import Flask, render_template, request
from flask_socketio import SocketIO, emit
from transformers import AutoModelForCausalLM, AutoTokenizer
import speech_recognition as sr
from generator import load_csm_1b, Segment
from collections import deque

app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key'
socketio = SocketIO(app, cors_allowed_origins="*")

# Select the best available device
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# Initialize CSM model for audio generation
print("Loading CSM model...")
csm_generator = load_csm_1b(device=device)

# Initialize Llama 3.2 model for response generation
print("Loading Llama 3.2 model...")
llm_model_id = "meta-llama/Llama-3.2-1B"  # Choose appropriate size based on resources
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_model_id,
    torch_dtype=torch.bfloat16,
    device_map=device
)

# Initialize speech recognition
recognizer = sr.Recognizer()

# Store conversation context
conversation_context = {}  # session_id -> context

@app.route('/')
def index():
    return render_template('index.html')

@socketio.on('connect')
def handle_connect():
    print(f"Client connected: {request.sid}")
    conversation_context[request.sid] = {
        'segments': [],
        'speakers': [0, 1],  # 0 = user, 1 = bot
        'audio_buffer': deque(maxlen=10),  # Store recent audio chunks
        'is_speaking': False,
        'silence_start': None
    }
    emit('ready', {'message': 'Connection established'})

@socketio.on('disconnect')
def handle_disconnect():
    print(f"Client disconnected: {request.sid}")
    if request.sid in conversation_context:
        del conversation_context[request.sid]

@socketio.on('start_speaking')
def handle_start_speaking():
    if request.sid in conversation_context:
        conversation_context[request.sid]['is_speaking'] = True
        conversation_context[request.sid]['audio_buffer'].clear()
        print(f"User {request.sid} started speaking")

@socketio.on('audio_chunk')
def handle_audio_chunk(data):
    if request.sid not in conversation_context:
        return

    context = conversation_context[request.sid]

    # Decode audio data
    audio_data = base64.b64decode(data['audio'])
    audio_numpy = np.frombuffer(audio_data, dtype=np.float32)
    audio_tensor = torch.tensor(audio_numpy)

    # Add to buffer
    context['audio_buffer'].append(audio_tensor)

    # Check for silence to detect end of speech
    if context['is_speaking'] and is_silence(audio_tensor):
        if context['silence_start'] is None:
            context['silence_start'] = time.time()
        elif time.time() - context['silence_start'] > 1.0:  # 1 second of silence
            # Process the complete utterance
            process_user_utterance(request.sid)
    else:
        context['silence_start'] = None

@socketio.on('stop_speaking')
def handle_stop_speaking():
    if request.sid in conversation_context:
        conversation_context[request.sid]['is_speaking'] = False
        process_user_utterance(request.sid)
        print(f"User {request.sid} stopped speaking")

def is_silence(audio_tensor, threshold=0.02):
    """Check if an audio chunk is silence based on amplitude threshold"""
    return torch.mean(torch.abs(audio_tensor)) < threshold

def process_user_utterance(session_id):
    """Process completed user utterance, generate response and send audio back"""
    context = conversation_context[session_id]

    if not context['audio_buffer']:
        return

    # Combine audio chunks
    full_audio = torch.cat(list(context['audio_buffer']), dim=0)
    context['audio_buffer'].clear()
    context['is_speaking'] = False
    context['silence_start'] = None

    # Convert audio to 16kHz for speech recognition
    audio_16k = torchaudio.functional.resample(
        full_audio,
        orig_freq=44100,  # Assuming 44.1kHz from client
        new_freq=16000
    )

    # Transcribe speech
    try:
        # Convert to wav format for speech_recognition
        audio_data = io.BytesIO()
        torchaudio.save(audio_data, audio_16k.unsqueeze(0), 16000, format="wav")
        audio_data.seek(0)

        with sr.AudioFile(audio_data) as source:
            audio = recognizer.record(source)
            user_text = recognizer.recognize_google(audio)
            print(f"Transcribed: {user_text}")

        # Add to conversation segments
        user_segment = Segment(
            text=user_text,
            speaker=0,  # User is speaker 0
            audio=full_audio
        )
        context['segments'].append(user_segment)

        # Generate bot response
        bot_response = generate_llm_response(user_text, context['segments'])
        print(f"Bot response: {bot_response}")

        # Convert to audio using CSM
        bot_audio = generate_audio_response(bot_response, context['segments'])

        # Convert audio to base64 for sending over websocket
        audio_bytes = io.BytesIO()
        torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav")
        audio_bytes.seek(0)
        audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8')

        # Add bot response to conversation history
        bot_segment = Segment(
            text=bot_response,
            speaker=1,  # Bot is speaker 1
            audio=bot_audio
        )
        context['segments'].append(bot_segment)

        # Send transcribed text to client
        emit('transcription', {'text': user_text}, room=session_id)

        # Send audio response to client
        emit('audio_response', {
            'audio': audio_b64,
            'text': bot_response
        }, room=session_id)

    except Exception as e:
        print(f"Error processing speech: {e}")
        emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id)

def generate_llm_response(user_text, conversation_segments):
    """Generate text response using Llama 3.2"""
    # Format conversation history for the LLM
    conversation_history = ""
    for segment in conversation_segments[-5:]:  # Use last 5 utterances for context
        speaker_name = "User" if segment.speaker == 0 else "Assistant"
        conversation_history += f"{speaker_name}: {segment.text}\n"

    # Add the current user query
    conversation_history += f"User: {user_text}\nAssistant:"

    # Generate response
    inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device)
    output = llm_model.generate(
        inputs.input_ids,
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )

    response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return response.strip()

def generate_audio_response(text, conversation_segments):
    """Generate audio response using CSM"""
    # Use the last few conversation segments as context
    context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments

    # Generate audio for bot response
    audio = csm_generator.generate(
        text=text,
        speaker=1,  # Bot is speaker 1
        context=context_segments,
        max_audio_length_ms=10000,  # 10 seconds max
        temperature=0.9,
        topk=50
    )

    return audio

if __name__ == '__main__':
    socketio.run(app, host='0.0.0.0', port=5000, debug=True)
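End-of-utterance detection in handle_audio_chunk() is a plain amplitude check (is_silence, mean absolute amplitude below 0.02) held for one second. A tiny illustration of that check with made-up sample values:

# Illustration of the amplitude-based silence check used above.
# The tensors are synthetic; 0.02 is the threshold defined in is_silence().
import torch

quiet = torch.full((4410,), 0.005)   # low-amplitude chunk -> treated as silence
speech = torch.full((4410,), 0.10)   # louder chunk -> treated as speech

threshold = 0.02
print(torch.mean(torch.abs(quiet)) < threshold)   # tensor(True)
print(torch.mean(torch.abs(speech)) < threshold)  # tensor(False)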
Backend/index.html (new file, 212 lines)
@@ -0,0 +1,212 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Audio Conversation Bot</title>
    <script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
    <style>
        body {
            font-family: Arial, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
        }
        #conversation {
            height: 400px;
            border: 1px solid #ccc;
            padding: 15px;
            margin-bottom: 20px;
            overflow-y: auto;
        }
        .user-message {
            background-color: #e1f5fe;
            padding: 10px;
            border-radius: 8px;
            margin-bottom: 10px;
            align-self: flex-end;
        }
        .bot-message {
            background-color: #f1f1f1;
            padding: 10px;
            border-radius: 8px;
            margin-bottom: 10px;
        }
        #controls {
            display: flex;
            gap: 10px;
        }
        button {
            padding: 10px 20px;
            font-size: 16px;
            cursor: pointer;
        }
        #recordButton {
            background-color: #4CAF50;
            color: white;
            border: none;
            border-radius: 4px;
        }
        #recordButton.recording {
            background-color: #f44336;
        }
        #status {
            margin-top: 10px;
            font-style: italic;
        }
    </style>
</head>
<body>
    <h1>Audio Conversation Bot</h1>
    <div id="conversation"></div>
    <div id="controls">
        <button id="recordButton">Hold to Speak</button>
    </div>
    <div id="status">Not connected</div>

    <script>
        const socket = io();
        const recordButton = document.getElementById('recordButton');
        const conversation = document.getElementById('conversation');
        const status = document.getElementById('status');

        let mediaRecorder;
        let audioChunks = [];
        let isRecording = false;

        // Initialize audio context and analyzer
        const audioContext = new (window.AudioContext || window.webkitAudioContext)();

        // Connect to server
        socket.on('connect', () => {
            status.textContent = 'Connected to server';
        });

        socket.on('ready', (data) => {
            status.textContent = data.message;
            setupAudioRecording();
        });

        socket.on('transcription', (data) => {
            addMessage('user', data.text);
        });

        socket.on('audio_response', (data) => {
            // Play audio
            const audio = new Audio('data:audio/wav;base64,' + data.audio);
            audio.play();

            // Display text
            addMessage('bot', data.text);
        });

        socket.on('error', (data) => {
            status.textContent = data.message;
            console.error(data.message);
        });

        function setupAudioRecording() {
            // Get user media
            navigator.mediaDevices.getUserMedia({ audio: true })
                .then(stream => {
                    // Setup recording
                    mediaRecorder = new MediaRecorder(stream);

                    mediaRecorder.ondataavailable = event => {
                        if (event.data.size > 0) {
                            audioChunks.push(event.data);
                        }
                    };

                    mediaRecorder.onstop = () => {
                        const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
                        audioChunks = [];

                        // Convert to Float32Array for sending
                        const fileReader = new FileReader();
                        fileReader.onloadend = () => {
                            const arrayBuffer = fileReader.result;
                            const floatArray = new Float32Array(arrayBuffer);

                            // Convert to base64
                            const base64String = arrayBufferToBase64(floatArray.buffer);
                            socket.emit('audio_chunk', { audio: base64String });
                        };
                        fileReader.readAsArrayBuffer(audioBlob);

                        socket.emit('stop_speaking');
                        isRecording = false;
                    };

                    // Setup audio analyzer for chunking and VAD
                    const source = audioContext.createMediaStreamSource(stream);
                    const analyzer = audioContext.createAnalyser();
                    analyzer.fftSize = 2048;
                    source.connect(analyzer);

                    // Setup button handlers
                    recordButton.addEventListener('mousedown', startRecording);
                    recordButton.addEventListener('touchstart', startRecording);
                    recordButton.addEventListener('mouseup', stopRecording);
                    recordButton.addEventListener('touchend', stopRecording);
                    recordButton.addEventListener('mouseleave', stopRecording);

                    status.textContent = 'Ready to record';
                })
                .catch(err => {
                    status.textContent = 'Error accessing microphone: ' + err.message;
                    console.error('Error accessing microphone:', err);
                });
        }

        function startRecording() {
            if (!isRecording) {
                audioChunks = [];
                mediaRecorder.start(100); // Collect data in 100ms chunks
                recordButton.classList.add('recording');
                recordButton.textContent = 'Release to Stop';
                status.textContent = 'Recording...';
                isRecording = true;

                socket.emit('start_speaking');

                // Start sending audio chunks periodically
                audioSendInterval = setInterval(() => {
                    if (mediaRecorder.state === 'recording') {
                        mediaRecorder.requestData(); // Force ondataavailable to fire
                    }
                }, 300); // Send every 300ms
            }
        }

        function stopRecording() {
            if (isRecording) {
                clearInterval(audioSendInterval);
                mediaRecorder.stop();
                recordButton.classList.remove('recording');
                recordButton.textContent = 'Hold to Speak';
                status.textContent = 'Processing...';
            }
        }

        function addMessage(sender, text) {
            const messageDiv = document.createElement('div');
            messageDiv.className = sender === 'user' ? 'user-message' : 'bot-message';
            messageDiv.textContent = text;
            conversation.appendChild(messageDiv);
            conversation.scrollTop = conversation.scrollHeight;
        }

        function arrayBufferToBase64(buffer) {
            let binary = '';
            const bytes = new Uint8Array(buffer);
            const len = bytes.byteLength;
            for (let i = 0; i < len; i++) {
                binary += String.fromCharCode(bytes[i]);
            }
            return window.btoa(binary);
        }
    </script>
</body>
</html>
Backend/run_csm.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass

# Disable Triton compilation
os.environ["NO_TORCH_COMPILE"] = "1"

# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_b.wav"
)

SPEAKER_PROMPTS = {
    "conversational_a": {
        "text": (
            "like revising for an exam I'd have to try and like keep up the momentum because I'd "
            "start really early I'd be like okay I'm gonna start revising now and then like "
            "you're revising for ages and then I just like start losing steam I didn't do that "
            "for the exam we had recently to be fair that was a more of a last minute scenario "
            "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
            "sort of start the day with this not like a panic but like a"
        ),
        "audio": prompt_filepath_conversational_a
    },
    "conversational_b": {
        "text": (
            "like a super Mario level. Like it's very like high detail. And like, once you get "
            "into the park, it just like, everything looks like a computer game and they have all "
            "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
            "will have like a question block. And if you like, you know, punch it, a coin will "
            "come out. So like everyone, when they come into the park, they get like this little "
            "bracelet and then you can go punching question blocks around."
        ),
        "audio": prompt_filepath_conversational_b
    }
}

def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = audio_tensor.squeeze(0)
    # Resample is lazy so we can always call it
    audio_tensor = torchaudio.functional.resample(
        audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
    )
    return audio_tensor

def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
    audio_tensor = load_prompt_audio(audio_path, sample_rate)
    return Segment(text=text, speaker=speaker, audio=audio_tensor)

def main():
    # Select the best available device, skipping MPS due to float64 limitations
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Load model
    generator = load_csm_1b(device)

    # Prepare prompts
    prompt_a = prepare_prompt(
        SPEAKER_PROMPTS["conversational_a"]["text"],
        0,
        SPEAKER_PROMPTS["conversational_a"]["audio"],
        generator.sample_rate
    )
    prompt_b = prepare_prompt(
        SPEAKER_PROMPTS["conversational_b"]["text"],
        1,
        SPEAKER_PROMPTS["conversational_b"]["audio"],
        generator.sample_rate
    )

    # Generate conversation
    conversation = [
        {"text": "Hey how are you doing?", "speaker_id": 0},
        {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
        {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
        {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
    ]

    # Generate each utterance
    generated_segments = []
    prompt_segments = [prompt_a, prompt_b]

    for utterance in conversation:
        print(f"Generating: {utterance['text']}")
        audio_tensor = generator.generate(
            text=utterance['text'],
            speaker=utterance['speaker_id'],
            context=prompt_segments + generated_segments,
            max_audio_length_ms=10_000,
        )
        generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))

    # Concatenate all generations
    all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
    torchaudio.save(
        "full_conversation.wav",
        all_audio.unsqueeze(0).cpu(),
        generator.sample_rate
    )
    print("Successfully generated full_conversation.wav")

if __name__ == "__main__":
    main()
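An optional variation on the final concatenation step above: keep each generated utterance as its own WAV as well. The helper and filenames are illustrative, not part of the script:

# Sketch of writing each generated Segment to a separate file,
# reusing the same generated_segments list and sample rate as above.
import torchaudio

def save_segments(generated_segments, sample_rate):
    for i, seg in enumerate(generated_segments):
        torchaudio.save(f"utterance_{i}_speaker{seg.speaker}.wav",
                        seg.audio.unsqueeze(0).cpu(), sample_rate)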
@@ -1,19 +0,0 @@
"""
CSM Voice Chat Server

A voice chat application that uses CSM 1B for voice synthesis,
WhisperX for speech recognition, and Llama 3.2 for language generation.
"""

# Start the Flask application
from api.app import app, socketio

if __name__ == '__main__':
    import os

    port = int(os.environ.get('PORT', 5000))
    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'

    print(f"Starting server on port {port} (debug={debug_mode})")
    print("Visit http://localhost:5000 to access the application")

    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)