import os
import threading

import numpy as np
from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv

from train import compute_index
from models import load_model, MLP

# Load environment variables from a .env file
load_dotenv()

app = Flask(__name__)

# Enable CORS for all routes, origins, and methods
CORS(app, resources={
    r"/*": {
        "origins": "*",
        "methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        "allow_headers": ["Content-Type", "Authorization", "Accept", "Origin", "X-Requested-With"]
    }
})

# Heavier ML imports (train.train, openmeteo_inference) are deferred to the
# routes that need them, so simple runs stay fast.


@app.route('/')
def home():
    return "<h1>Welcome to the Flask App</h1><p>Try /get-data or /health endpoints.</p>"


@app.route('/get-data', methods=['GET'])
def get_data():
    # Example GET request handler
    data = {"message": "Hello from Flask!"}
    return jsonify(data)
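

# Illustrative usage (assumes the default Flask dev server on port 5000):
#   curl http://127.0.0.1:5000/get-data
#   -> {"message": "Hello from Flask!"}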


@app.route('/post-data', methods=['POST'])
def post_data():
    # Example POST request handler; tolerate a missing or non-JSON body
    # instead of letting Flask raise a 415
    content = request.get_json(silent=True)
    if content is None:
        return jsonify({"error": "request body must be valid JSON"}), 400
    # Process content or call the AI model here
    response = {"you_sent": content}
    return jsonify(response)
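

# Illustrative usage (the endpoint simply echoes the posted JSON):
#   curl -X POST http://127.0.0.1:5000/post-data \
#        -H "Content-Type: application/json" \
#        -d '{"hello": "world"}'
#   -> {"you_sent": {"hello": "world"}}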


@app.route('/train', methods=['POST'])
def train_endpoint():
    """Trigger training. Expects JSON: {"data_root": "path/to/data", "epochs": 3}

    Training runs in a background thread and saves the model to model.pth in
    the repo root.
    """
    payload = request.get_json(silent=True) or {}
    data_root = payload.get('data_root')
    try:
        epochs = int(payload.get('epochs', 3))
    except (TypeError, ValueError):
        return jsonify({"error": "epochs must be an integer"}), 400
    if not data_root or not os.path.isdir(data_root):
        return jsonify({"error": "data_root must be a valid directory path"}), 400

    def _run_training():
        # Lazy import: keeps heavy training dependencies off the request path
        from train import train
        train(data_root, epochs=epochs)

    t = threading.Thread(target=_run_training, daemon=True)
    t.start()
    return jsonify({"status": "training_started"})
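

# Illustrative usage (the data_root path is hypothetical):
#   curl -X POST http://127.0.0.1:5000/train \
#        -H "Content-Type: application/json" \
#        -d '{"data_root": "./data", "epochs": 3}'
#   -> {"status": "training_started"}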


@app.route('/health', methods=['GET'])
def health():
    """Return status of loaded ML artifacts (model, centers, preprocess_meta)."""
    try:
        from openmeteo_inference import init_inference
        status = init_inference()
        return jsonify({'ok': True, 'artifacts': status})
    except Exception as e:
        return jsonify({'ok': False, 'error': str(e)}), 500
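

# Illustrative usage (the exact 'artifacts' keys depend on what
# openmeteo_inference.init_inference reports):
#   curl http://127.0.0.1:5000/health
#   -> {"ok": true, "artifacts": {...}}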


@app.route('/predict', methods=['POST', 'GET'])
def predict_endpoint():
    """Predict a route between two points given source and destination lat/lon."""

    # Default example coordinates for GET requests
    default_src_lat = 38.9
    default_src_lon = -77.0
    default_dst_lat = 38.95
    default_dst_lon = -77.02

    if request.method == 'GET':
        # Read from query params or fall back to the defaults
        try:
            src_lat = float(request.args.get('sourceLat', default_src_lat))
            src_lon = float(request.args.get('sourceLon', default_src_lon))
            dst_lat = float(request.args.get('destLat', default_dst_lat))
            dst_lon = float(request.args.get('destLon', default_dst_lon))
        except (TypeError, ValueError):
            return jsonify({"error": "invalid query parameters; coordinates must be numbers"}), 400
    else:  # POST
        data = request.get_json(silent=True) or {}
        source = data.get("source")
        destination = data.get("destination")

        if not source or not destination:
            return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400

        try:
            src_lat = float(source.get("lat"))
            src_lon = float(source.get("lon"))
            dst_lat = float(destination.get("lat"))
            dst_lon = float(destination.get("lon"))
        except (TypeError, ValueError, AttributeError):
            return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400

    # Load the model (loader infers architecture from the checkpoint)
    try:
        model = load_model('model.pth', MLP)
    except Exception as e:
        return jsonify({"error": "model load failed", "detail": str(e)}), 500

    # Infer the expected input dim from the first 2-D tensor in the state
    # dict, assumed to be the first linear layer's weight
    try:
        input_dim = None
        for v in model.state_dict().values():
            if hasattr(v, "dim") and v.dim() == 2:
                input_dim = int(v.shape[1])
                break
        if input_dim is None:
            input_dim = 2
    except Exception:
        input_dim = 2
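
    # Why shape[1]: a torch.nn.Linear(in_features, out_features) stores its
    # weight as an (out_features, in_features) tensor, so the second axis of
    # the first 2-D weight is the model's expected input width.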

    # Build a feature vector of the expected length and populate coordinates,
    # using preprocessing metadata when it is available
    feature_vector = np.zeros(int(input_dim), dtype=float)
    meta_path = os.path.join(os.getcwd(), 'preprocess_meta.npz')

    if os.path.exists(meta_path):
        try:
            meta = np.load(meta_path, allow_pickle=True)
            cols = [str(x) for x in meta['feature_columns'].tolist()]
            # NpzFile lacks a reliable .get across numpy versions; check .files
            means = meta['means'] if 'means' in meta.files else None
            if means is not None and len(means) == input_dim:
                feature_vector[:] = means

            col_lower = [c.lower() for c in cols]
            print(f"📋 Available columns: {col_lower[:10]}...")  # show first 10 columns

            # Try to find and populate coordinate fields
            coord_mappings = [
                (('lat', 'latitude', 'src_lat', 'source_lat'), src_lat),
                (('lon', 'lng', 'longitude', 'src_lon', 'source_lon'), src_lon),
                (('dst_lat', 'dest_lat', 'destination_lat', 'end_lat'), dst_lat),
                (('dst_lon', 'dest_lon', 'destination_lon', 'end_lon', 'dst_lng'), dst_lon)
            ]

            for possible_names, value in coord_mappings:
                for name in possible_names:
                    if name in col_lower:
                        idx = col_lower.index(name)
                        feature_vector[idx] = value
                        print(f"✅ Mapped {name} (index {idx}) = {value}")
                        break

            # Route-level features the training pipeline may have used
            route_distance = ((dst_lat - src_lat) ** 2 + (dst_lon - src_lon) ** 2) ** 0.5
            midpoint_lat = (src_lat + dst_lat) / 2
            midpoint_lon = (src_lon + dst_lon) / 2

            additional_features = {
                'distance': route_distance,
                'route_distance': route_distance,
                'midpoint_lat': midpoint_lat,
                'midpoint_lon': midpoint_lon,
                'lat_diff': abs(dst_lat - src_lat),
                'lon_diff': abs(dst_lon - src_lon)
            }

            for feature_name, feature_value in additional_features.items():
                if feature_name in col_lower:
                    idx = col_lower.index(feature_name)
                    feature_vector[idx] = feature_value
                    print(f"✅ Mapped {feature_name} (index {idx}) = {feature_value}")

        except Exception as e:
            print(f"⚠️ Error processing metadata: {e}")
            # Fall back to a simple positional coordinate mapping
            feature_vector[:] = 0.0
            feature_vector[0] = src_lat
            if input_dim > 1:
                feature_vector[1] = src_lon
            if input_dim > 2:
                feature_vector[2] = dst_lat
            if input_dim > 3:
                feature_vector[3] = dst_lon
    else:
        print("⚠️ No preprocess_meta.npz found, using simple coordinate mapping")
        # Simple positional fallback mapping
        feature_vector[0] = src_lat
        if input_dim > 1:
            feature_vector[1] = src_lon
        if input_dim > 2:
            feature_vector[2] = dst_lat
        if input_dim > 3:
            feature_vector[3] = dst_lon

        # Derived features to give the model input more variation
        if input_dim > 4:
            feature_vector[4] = ((dst_lat - src_lat) ** 2 + (dst_lon - src_lon) ** 2) ** 0.5  # distance
        if input_dim > 5:
            feature_vector[5] = (src_lat + dst_lat) / 2  # midpoint lat
        if input_dim > 6:
            feature_vector[6] = (src_lon + dst_lon) / 2  # midpoint lon
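
    # For illustration: the metadata consumed above could have been written
    # during preprocessing roughly like this (only 'feature_columns' and
    # 'means' are read here; the column names are hypothetical):
    #   np.savez('preprocess_meta.npz',
    #            feature_columns=np.array(['lat', 'lon', 'dst_lat', 'dst_lon']),
    #            means=np.zeros(4))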

    # Compute the index using the model
    try:
        print(f"Feature vector for prediction: {feature_vector[:8]}...")  # show first 8 values
        print(f"Coordinates: src({src_lat}, {src_lon}) → dst({dst_lat}, {dst_lon})")
        index = compute_index(model, feature_vector)
        print(f"Computed index: {index}")
    except Exception as e:
        return jsonify({"error": "compute_index failed", "detail": str(e)}), 500

    response_payload = {
        "index": index,
        "prediction": {},
        "called_with": request.method,
        "diagnostics": {"input_dim": int(input_dim)}
    }

    return jsonify(response_payload), 200
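

# Illustrative usage (coordinates mirror the GET defaults above):
#   curl -X POST http://127.0.0.1:5000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"source": {"lat": 38.9, "lon": -77.0},
#             "destination": {"lat": 38.95, "lon": -77.02}}'
#   or simply: curl "http://127.0.0.1:5000/predict?sourceLat=38.9&sourceLon=-77.0"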


if __name__ == '__main__':
    # Eagerly load model/artifacts at startup (best-effort)
    try:
        from openmeteo_inference import init_inference
        init_inference()
    except Exception:
        pass
    app.run(debug=True)
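
# Run locally for development (app.run uses Flask's default port, 5000):
#   python app.py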