import os
import logging
import threading
import traceback
from dataclasses import dataclass
from typing import Any, Optional

from flask import Flask
from flask_socketio import SocketIO
from flask_cors import CORS

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure device
import torch

if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

logger.info(f"Using device: {DEVICE}")

# Initialize Flask app
app = Flask(__name__, static_folder='../', static_url_path='')
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)
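# NOTE: cors_allowed_origins="*" is convenient in development but too permissive
# for production; the generous ping_timeout presumably keeps clients connected
# while long-running model work blocks the server.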

# Global variables for conversation state
active_conversations = {}
user_queues = {}
processing_threads = {}
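# (All three are presumably keyed by the client's Socket.IO session id;
# register_handlers below receives them and manages their lifecycle.)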

# Model storage
@dataclass
class AppModels:
    """Holds every model the app uses; attributes stay None until load_models() populates them."""
    generator: Any = None
    tokenizer: Any = None
    llm: Any = None
    whisperx_model: Any = None
    whisperx_align_model: Any = None
    whisperx_align_metadata: Any = None
    last_language: Optional[str] = None


models = AppModels()
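# The loader thread writes these attributes while handler threads read them;
# attribute assignment is atomic under the GIL, which is presumably why no
# lock guards this object.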


def load_models():
    """Load all required models, reporting per-model status over Socket.IO."""
    from generator import load_csm_1b
    import whisperx
    from transformers import AutoModelForCausalLM, AutoTokenizer
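
    # Status events are broadcast to whoever is connected; a client that joins
    # mid-load is assumed to re-request status from the frontend (an assumption
    # about the frontend, not something this module enforces).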
    socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})
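
    # Each model loads in its own try/except so one failure is reported as a
    # per-model error event instead of aborting the whole startup sequence.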

    # CSM 1B loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10,
                                       'message': 'Loading CSM voice model'})
        models.generator = load_csm_1b(device=DEVICE)
        logger.info("CSM 1B model loaded successfully")
        socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()  # free cached GPU memory before the next large load
    except Exception as e:
        error_details = traceback.format_exc()
        logger.error(f"Error loading CSM 1B model: {e}\n{error_details}")
        socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})

    # WhisperX loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40,
                                       'message': 'Loading speech recognition model'})
        # Use WhisperX for better transcription with timestamps;
        # pick the compute type to match the device: float16 on GPU, float32 on CPU
        compute_type = "float16" if DEVICE == "cuda" else "float32"

        # Load the WhisperX model (smaller model for faster processing)
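        # NOTE: whisperx is backed by CTranslate2, which targets "cuda" and
        # "cpu"; on an "mps" device this call is expected to raise and be
        # reported through the 'asr' error event below (untested assumption).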
        models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)

        logger.info("WhisperX model loaded successfully")
        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        error_details = traceback.format_exc()
        logger.error(f"Error loading WhisperX model: {e}\n{error_details}")
        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})

    # Llama loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70,
                                       'message': 'Loading language model'})
        models.llm = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            device_map=DEVICE,
            torch_dtype=torch.bfloat16
        )
        models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

        # Configure all special tokens
        models.tokenizer.pad_token = models.tokenizer.eos_token
        models.tokenizer.padding_side = "left"  # for causal language modeling

        # Inform the model about the pad token
        if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
            models.llm.config.pad_token_id = models.tokenizer.pad_token_id
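        # (Llama checkpoints ship without a dedicated pad token, hence reusing
        # eos above; left padding keeps prompts right-aligned, which is what
        # decoder-only generation expects for batched inputs.)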

        logger.info("Llama 3.2 model loaded successfully")
        socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100,
                                       'message': 'All models loaded successfully'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
    except Exception as e:
        error_details = traceback.format_exc()
        logger.error(f"Error loading Llama 3.2 model: {e}\n{error_details}")
        socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})


# Load models in a background thread
threading.Thread(target=load_models, daemon=True).start()

# Import routes and socket handlers
from api.routes import register_routes
from api.socket_handlers import register_handlers
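# (Deferred to the bottom of the module: if api.routes or api.socket_handlers
# import anything from this module at import time, hoisting these lines to the
# top would create a circular import.)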

# Register routes and socket handlers
register_routes(app)
register_handlers(socketio, app, models, active_conversations, user_queues,
                  processing_threads, DEVICE)

# Run server if executed directly
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
    logger.info(f"Starting server on port {port} (debug={debug_mode})")
    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)