# VTHacks13/roadcast/app.py

from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv
from train import compute_index
from models import load_model, MLP

import os
import threading
import json
import numpy as np

# ML imports are lazy to avoid heavy imports on simple runs

# Load environment variables from .env file
load_dotenv()

app = Flask(__name__)

# Enable CORS for all routes, origins, and methods
CORS(app, resources={
    r"/*": {
        "origins": "*",
        "methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        "allow_headers": ["Content-Type", "Authorization", "Accept", "Origin", "X-Requested-With"]
    }
})


@app.route('/')
def home():
    return "<h1>Welcome to the Flask App</h1><p>Try /get-data or /health endpoints.</p>"


@app.route('/get-data', methods=['GET'])
def get_data():
    # Example GET request handler
    data = {"message": "Hello from Flask!"}
    return jsonify(data)


@app.route('/post-data', methods=['POST'])
def post_data():
    # Example POST request handler
    content = request.json
    # Process content or call AI model here
    response = {"you_sent": content}
    return jsonify(response)


@app.route('/train', methods=['POST'])
def train_endpoint():
    """Trigger training. Expects JSON: {"data_root": "path/to/data", "epochs": 3}.

    Training runs in a background thread and saves the model to model.pth in the repo root.
    """
    payload = request.get_json(silent=True) or {}
    data_root = payload.get('data_root')
    epochs = int(payload.get('epochs', 3))
    if not data_root or not os.path.isdir(data_root):
        return jsonify({"error": "data_root must be a valid directory path"}), 400

    def _run_training():
        # lazy import so heavy training dependencies load only when needed
        from train import train
        train(data_root, epochs=epochs)

    t = threading.Thread(target=_run_training, daemon=True)
    t.start()
    return jsonify({"status": "training_started"})


@app.route('/health', methods=['GET'])
def health():
    """Return status of loaded ML artifacts (model, centers, preprocess_meta)."""
    try:
        from openmeteo_inference import init_inference
        status = init_inference()
        return jsonify({'ok': True, 'artifacts': status})
    except Exception as e:
        return jsonify({'ok': False, 'error': str(e)}), 500


@app.route('/predict', methods=['POST', 'GET'])
def predict_endpoint():
    """Predict an index for the route between two points given source and destination lat/lon."""
    # Default example values for GET
    default_src_lat = 38.9
    default_src_lon = -77.0
    default_dst_lat = 38.95
    default_dst_lon = -77.02

    if request.method == 'GET':
        # Read from query params or fall back to defaults
        try:
            src_lat = float(request.args.get('sourceLat', default_src_lat))
            src_lon = float(request.args.get('sourceLon', default_src_lon))
            dst_lat = float(request.args.get('destLat', default_dst_lat))
            dst_lon = float(request.args.get('destLon', default_dst_lon))
        except (TypeError, ValueError):
            return jsonify({"error": "Invalid query parameters"}), 400
    else:  # POST
        data = request.get_json(silent=True) or {}
        source = data.get("source")
        destination = data.get("destination")
        if not source or not destination:
            return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400
        try:
            src_lat = float(source.get("lat"))
            src_lon = float(source.get("lon"))
            dst_lat = float(destination.get("lat"))
            dst_lon = float(destination.get("lon"))
        except (TypeError, ValueError):
            return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400

    # load model (loader infers architecture from checkpoint)
    try:
        model = load_model('model.pth', MLP)
    except Exception as e:
        return jsonify({"error": "model load failed", "detail": str(e)}), 500

    # infer expected input dim from model first linear weight
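    # state_dict() preserves registration order, so the first 2-D tensor found is assumed to be
    # the first Linear layer's weight (shape [out_features, in_features]); its second dimension
    # is taken as the number of input features.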
    try:
        input_dim = None
        for v in model.state_dict().values():
            if getattr(v, "dim", None) and v.dim() == 2:
                input_dim = int(v.shape[1])
                break
        if input_dim is None:
            input_dim = 2
    except Exception:
        input_dim = 2

    # build feature vector of correct length and populate lat/lon using preprocess meta if available
    feature_vector = np.zeros(int(input_dim), dtype=float)
    meta_path = os.path.join(os.getcwd(), 'preprocess_meta.npz')
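    # preprocess_meta.npz (presumably produced by the preprocessing/training step) is expected to
    # contain 'feature_columns' (ordered feature names) and optionally 'means' (per-feature
    # training means used to seed the vector).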
    if os.path.exists(meta_path):
        try:
            meta = np.load(meta_path, allow_pickle=True)
            cols = [str(x) for x in meta['feature_columns'].tolist()]
            means = meta.get('means')
            if means is not None and len(means) == input_dim:
                feature_vector[:] = means
            col_lower = [c.lower() for c in cols]
            print(f"📋 Available columns: {col_lower[:10]}...")  # Show first 10 columns

            # Try to find and populate coordinate fields
            coord_mappings = [
                (('lat', 'latitude', 'src_lat', 'source_lat'), src_lat),
                (('lon', 'lng', 'longitude', 'src_lon', 'source_lon'), src_lon),
                (('dst_lat', 'dest_lat', 'destination_lat', 'end_lat'), dst_lat),
                (('dst_lon', 'dest_lon', 'destination_lon', 'end_lon', 'dst_lng'), dst_lon)
            ]
            for possible_names, value in coord_mappings:
                for name in possible_names:
                    if name in col_lower:
                        idx = col_lower.index(name)
                        feature_vector[idx] = value
                        print(f"✅ Mapped {name} (index {idx}) = {value}")
                        break

            # Calculate route features that might be useful
            route_distance = ((dst_lat - src_lat)**2 + (dst_lon - src_lon)**2)**0.5
            midpoint_lat = (src_lat + dst_lat) / 2
            midpoint_lon = (src_lon + dst_lon) / 2
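            # Note: route_distance is a planar distance in degrees, a rough proxy rather than a
            # true geodesic distance.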

            # Try to populate additional features that might exist
            additional_features = {
                'distance': route_distance,
                'route_distance': route_distance,
                'midpoint_lat': midpoint_lat,
                'midpoint_lon': midpoint_lon,
                'lat_diff': abs(dst_lat - src_lat),
                'lon_diff': abs(dst_lon - src_lon)
            }
            for feature_name, feature_value in additional_features.items():
                if feature_name in col_lower:
                    idx = col_lower.index(feature_name)
                    feature_vector[idx] = feature_value
                    print(f"✅ Mapped {feature_name} (index {idx}) = {feature_value}")
        except Exception as e:
            print(f"⚠️ Error processing metadata: {e}")
            # Fallback to simple coordinate mapping
            feature_vector[:] = 0.0
            feature_vector[0] = src_lat
            if input_dim > 1:
                feature_vector[1] = src_lon
            if input_dim > 2:
                feature_vector[2] = dst_lat
            if input_dim > 3:
                feature_vector[3] = dst_lon
    else:
        print("⚠️ No preprocess_meta.npz found, using simple coordinate mapping")
        # Simple fallback mapping
        feature_vector[0] = src_lat
        if input_dim > 1:
            feature_vector[1] = src_lon
        if input_dim > 2:
            feature_vector[2] = dst_lat
        if input_dim > 3:
            feature_vector[3] = dst_lon
        # Add some derived features to create more variation
        if input_dim > 4:
            feature_vector[4] = ((dst_lat - src_lat)**2 + (dst_lon - src_lon)**2)**0.5  # distance
        if input_dim > 5:
            feature_vector[5] = (src_lat + dst_lat) / 2  # midpoint lat
        if input_dim > 6:
            feature_vector[6] = (src_lon + dst_lon) / 2  # midpoint lon

    # compute index using model
    try:
        print(f"Feature vector for prediction: {feature_vector[:8]}...")  # Show first 8 values
        print(f"Coordinates: src({src_lat}, {src_lon}) → dst({dst_lat}, {dst_lon})")
        index = compute_index(model, feature_vector)
        print(f"Computed index: {index}")
    except Exception as e:
        return jsonify({"error": "compute_index failed", "detail": str(e)}), 500

    response_payload = {
        "index": index,
        "prediction": {},
        "called_with": request.method,
        "diagnostics": {"input_dim": int(input_dim)}
    }
    return jsonify(response_payload), 200


if __name__ == '__main__':
    # eager load model/artifacts at startup (best-effort)
    try:
        from openmeteo_inference import init_inference
        init_inference()
    except Exception:
        pass
    app.run(debug=True)
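

# Example usage against the Flask dev server started above (http://127.0.0.1:5000 by default):
#   GET  /health
#   GET  /predict?sourceLat=38.9&sourceLon=-77.0&destLat=38.95&destLon=-77.02
#   POST /predict  {"source": {"lat": 38.9, "lon": -77.0}, "destination": {"lat": 38.95, "lon": -77.02}}
#   POST /train    {"data_root": "path/to/data", "epochs": 3}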