# Files: HooHacks-12/Backend/api/app.py
# 2025-03-30 03:43:08 -04:00 — 136 lines, 5.3 KiB, Python

import logging
import os
import threading
from dataclasses import dataclass
from typing import Any

from flask import Flask
from flask_socketio import SocketIO
from flask_cors import CORS
# Configure logging: INFO and above, each record prefixed with a timestamp,
# the emitting logger's name and the level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Configure device: prefer CUDA, then Apple-silicon MPS, then plain CPU.
import torch


def _pick_device():
    """Return the torch device string the models should run on.

    CUDA is checked first, then MPS; ``"cpu"`` is the fallback.
    """
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


DEVICE = _pick_device()
logger.info(f"Using device: {DEVICE}")
# Initialize the Flask app and its Socket.IO server.
# NOTE(review): Flask resolves static_folder relative to this module's
# directory, so '../' combined with static_url_path='' serves every file in the
# parent directory (presumably Backend/, which includes this source tree) at
# the URL root. Confirm that only frontend assets are reachable there.
app = Flask(
    __name__,
    static_url_path='',
    static_folder='../',
)
CORS(app)  # flask-cors defaults: cross-origin HTTP requests from any origin
socketio = SocketIO(
    app,
    ping_timeout=120,
    cors_allowed_origins="*",
)
# Mutable conversation state shared with the Socket.IO handlers; all three
# dicts start empty and are handed to register_handlers further down.
active_conversations: dict = dict()
user_queues: dict = dict()
processing_threads: dict = dict()
# Model storage
@dataclass
class AppModels:
    """Shared holder for the models that ``load_models`` populates.

    Every field defaults to ``None`` and stays ``None`` until (and unless) the
    corresponding model is assigned. The fields are annotated because
    ``@dataclass`` only turns *annotated* class attributes into fields;
    without annotations the decorator generated no fields at all.
    """

    generator: Any = None  # CSM 1B voice generator
    tokenizer: Any = None  # Llama 3.2 tokenizer
    llm: Any = None  # Llama 3.2 causal language model
    whisperx_model: Any = None  # WhisperX speech-recognition model
    # NOTE(review): the three fields below are not assigned in this file;
    # presumably the socket handlers fill them in lazily — confirm.
    whisperx_align_model: Any = None
    whisperx_align_metadata: Any = None
    last_language: Any = None


models = AppModels()
def _emit_status(model, status, **extra):
    """Emit one ``model_status`` Socket.IO event.

    Args:
        model: Component key: 'overall', 'csm', 'asr' or 'llm'.
        status: Status string: 'loading', 'loaded' or 'error'.
        **extra: Optional payload keys such as ``progress`` and ``message``.
    """
    socketio.emit('model_status', {'model': model, 'status': status, **extra})


def _load_csm():
    """Load the CSM 1B voice generator into ``models.generator``."""
    # Imported lazily so a missing package only fails this component.
    from generator import load_csm_1b

    models.generator = load_csm_1b(device=DEVICE)


def _load_asr():
    """Load the WhisperX "small" speech-recognition model into ``models.whisperx_model``."""
    import whisperx

    # WhisperX runs on faster-whisper/CTranslate2, which only supports "cpu"
    # and "cuda" devices; use the CPU instead of failing on Apple "mps".
    asr_device = "cuda" if DEVICE == "cuda" else "cpu"
    # float16 inference needs a GPU; the CPU path uses float32.
    compute_type = "float16" if asr_device == "cuda" else "float32"
    models.whisperx_model = whisperx.load_model("small", asr_device, compute_type=compute_type)


def _load_llm():
    """Load the Llama 3.2 1B model and its tokenizer into ``models``."""
    from transformers import AutoModelForCausalLM, AutoTokenizer

    llm = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B",
        device_map=DEVICE,
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    # Reuse EOS as the pad token (Llama tokenizers typically ship without one).
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # left padding for batched causal-LM generation
    # Tell the model which id is padding, if its config does not say already.
    if hasattr(llm.config, "pad_token_id") and llm.config.pad_token_id is None:
        llm.config.pad_token_id = tokenizer.pad_token_id
    # Publish only once both are fully configured, tokenizer first, so readers
    # never see a model without its tokenizer.
    models.tokenizer = tokenizer
    models.llm = llm


def _run_step(model, label, loader):
    """Run one component loader and report its outcome.

    Args:
        model: Component key used in the ``model_status`` event.
        label: Human-readable name used in log messages.
        loader: Zero-argument callable that loads the component into ``models``.

    Returns:
        True if ``loader`` completed. False if it raised; the error is then
        logged with its traceback and emitted as an ``error`` status.
    """
    try:
        loader()
    except Exception as e:
        # Best effort: one failed component must not stop the others loading.
        logger.exception("Error loading %s: %s", label, e)
        _emit_status(model, 'error', message=str(e))
        return False
    logger.info("%s loaded successfully", label)
    _emit_status(model, 'loaded')
    if DEVICE == "cuda":
        # Return cached-but-unused allocator memory left over from loading.
        torch.cuda.empty_cache()
    return True


def load_models():
    """Load the CSM voice, WhisperX ASR and Llama models into ``models``.

    Started in a background thread at import time. Progress is reported via
    ``model_status`` Socket.IO events:

    - per-component ``loaded``/``error`` events ('csm', 'asr', 'llm');
    - ``overall`` ``loading`` events carrying ``progress`` (and sometimes
      ``message``).

    Each component loads independently, so a failure leaves that component's
    attributes unassigned and loading continues with the next one. The final
    ``overall`` event is ``loaded`` only when every component succeeded;
    otherwise it is ``error`` naming the failed components.
    """
    # (component key, log label, progress before, message, loader, progress after)
    steps = (
        ('csm', 'CSM 1B model', 10, 'Loading CSM voice model', _load_csm, 33),
        ('asr', 'WhisperX model', 40, 'Loading speech recognition model', _load_asr, 66),
        ('llm', 'Llama 3.2 model', 70, 'Loading language model', _load_llm, None),
    )

    _emit_status('overall', 'loading', progress=0)
    failed = []
    for key, label, start, message, loader, done in steps:
        _emit_status('overall', 'loading', progress=start, message=message)
        if not _run_step(key, label, loader):
            failed.append(key)
        elif done is not None:
            _emit_status('overall', 'loading', progress=done)

    if failed:
        # NOTE(review): confirm the frontend handles status 'error' for 'overall'.
        _emit_status('overall', 'error', message=f"Failed to load: {', '.join(failed)}")
    else:
        _emit_status('overall', 'loading', progress=100, message='All models loaded successfully')
        _emit_status('overall', 'loaded')
# Kick off model loading on a daemon thread so this module finishes importing
# (and the server can start) while the slow model loads are still running.
_model_loader = threading.Thread(target=load_models, daemon=True)
_model_loader.start()

# Wire up HTTP routes and Socket.IO event handlers. These imports sit below the
# module's shared state — presumably to avoid a circular import with this
# module; confirm before hoisting them.
from api.routes import register_routes
from api.socket_handlers import register_handlers

register_routes(app)
register_handlers(
    socketio,
    app,
    models,
    active_conversations,
    user_queues,
    processing_threads,
    DEVICE,
)
# Run the server if this module is executed directly.
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
    logger.info(f"Starting server on port {port} (debug={debug_mode})")
    # Keep the Werkzeug reloader off even in debug mode. It re-executes this
    # module in a child process, and the model-loading thread started at import
    # time would then load every model twice (once per process), doubling
    # GPU/CPU memory use.
    socketio.run(
        app,
        host='0.0.0.0',
        port=port,
        debug=debug_mode,
        use_reloader=False,
        allow_unsafe_werkzeug=True,
    )