Frontend Fixed
Backend/api/app.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import os
import logging
import threading
from dataclasses import dataclass
from typing import Any

from flask import Flask
from flask_socketio import SocketIO
from flask_cors import CORS

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure device
import torch

if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

logger.info(f"Using device: {DEVICE}")
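
# NOTE (editorial): "mps" is PyTorch's Apple-Silicon GPU backend; its operator
# coverage is narrower than CUDA's, which is why "cpu" remains the final fallback.
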
# Initialize Flask app
app = Flask(__name__, static_folder='../', static_url_path='')
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)
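# ping_timeout is raised to 120s, presumably so that long-running model loads
# do not look like a dead connection and trip the Socket.IO client's disconnect.
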
# Global variables for conversation state
active_conversations = {}
user_queues = {}
processing_threads = {}

# Model storage
@dataclass
class AppModels:
    """Holds every loaded model plus related state in one place."""
    generator: Any = None               # CSM 1B voice generator
    tokenizer: Any = None               # Llama tokenizer
    llm: Any = None                     # Llama 3.2 language model
    whisperx_model: Any = None          # WhisperX ASR model
    whisperx_align_model: Any = None    # per-language alignment model
    whisperx_align_metadata: Any = None
    last_language: Any = None           # language of the last alignment model

models = AppModels()

def load_models():
    """Load all required models, emitting progress over Socket.IO."""
    from generator import load_csm_1b
    import whisperx
    from transformers import AutoModelForCausalLM, AutoTokenizer

    socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})
    # CSM 1B loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
        models.generator = load_csm_1b(device=DEVICE)
        logger.info("CSM 1B model loaded successfully")
        socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
        socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})

    # WhisperX loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
        # Use WhisperX for better transcription with timestamps;
        # float16 only works on CUDA, so fall back to float32 elsewhere
        compute_type = "float16" if DEVICE == "cuda" else "float32"

        # Load the WhisperX model (the "small" checkpoint trades accuracy for speed)
        models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)

        logger.info("WhisperX model loaded successfully")
        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
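    # NOTE (editorial): whisperx.load_model runs on CTranslate2, which supports
    # CPU and CUDA but, at the time of writing, not MPS; on Apple Silicon the
    # load above may therefore fail and need an explicit device="cpu".
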
    # Llama loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
        models.llm = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            device_map=DEVICE,
            torch_dtype=torch.bfloat16
        )
        models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

        # Llama ships without a pad token, so reuse EOS and pad on the left
        # (required for batched causal-LM generation)
        models.tokenizer.pad_token = models.tokenizer.eos_token
        models.tokenizer.padding_side = "left"

        # Keep the model config in sync with the tokenizer's pad token
        if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
            models.llm.config.pad_token_id = models.tokenizer.pad_token_id

        logger.info("Llama 3.2 model loaded successfully")
        socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
    except Exception as e:
        logger.error(f"Error loading Llama 3.2 model: {str(e)}")
        socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})

# Load models in a background thread
threading.Thread(target=load_models, daemon=True).start()
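# daemon=True means the loader thread never blocks interpreter shutdown, and
# loading in the background lets the HTTP/Socket.IO server come up immediately
# while clients watch 'model_status' events for readiness.
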
# Import and register routes and socket handlers
from api.routes import register_routes
from api.socket_handlers import register_handlers

register_routes(app)
register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE)

# Run server if executed directly
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
    logger.info(f"Starting server on port {port} (debug={debug_mode})")
    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
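
For reference, a minimal client for the 'model_status' events above might look
like this (a sketch, assuming the python-socketio client package and the
default port; not part of this commit):

    # pip install "python-socketio[client]"
    import socketio

    sio = socketio.Client()

    @sio.on('model_status')
    def on_model_status(data):
        # Print each progress event as it arrives from the server
        progress = data.get('progress', '?')
        print(f"{data['model']}: {data['status']} ({progress}%)")

    sio.connect('http://localhost:5000')
    sio.wait()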