Flask COrs
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from flask import Flask, request, jsonify
|
||||
from flask_cors import CORS
|
||||
from dotenv import load_dotenv
|
||||
from train import compute_index
|
||||
from models import load_model
|
||||
@@ -8,6 +9,14 @@ from models import MLP
|
||||
load_dotenv()
|
||||
|
||||
app = Flask(__name__)
|
||||
# Enable CORS for all routes, origins, and methods
|
||||
CORS(app, resources={
|
||||
r"/*": {
|
||||
"origins": "*",
|
||||
"methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
"allow_headers": ["Content-Type", "Authorization", "Accept", "Origin", "X-Requested-With"]
|
||||
}
|
||||
})
|
||||
import os
|
||||
import threading
|
||||
import json
|
||||
@@ -121,6 +130,7 @@ def predict_endpoint():
|
||||
# build feature vector of correct length and populate lat/lon using preprocess meta if available
|
||||
feature_vector = np.zeros(int(input_dim), dtype=float)
|
||||
meta_path = os.path.join(os.getcwd(), 'preprocess_meta.npz')
|
||||
|
||||
if os.path.exists(meta_path):
|
||||
try:
|
||||
meta = np.load(meta_path, allow_pickle=True)
|
||||
@@ -128,33 +138,83 @@ def predict_endpoint():
|
||||
means = meta.get('means')
|
||||
if means is not None and len(means) == input_dim:
|
||||
feature_vector[:] = means
|
||||
|
||||
col_lower = [c.lower() for c in cols]
|
||||
if 'lat' in col_lower:
|
||||
feature_vector[col_lower.index('lat')] = src_lat
|
||||
elif 'latitude' in col_lower:
|
||||
feature_vector[col_lower.index('latitude')] = src_lat
|
||||
else:
|
||||
feature_vector[0] = src_lat
|
||||
if 'lon' in col_lower:
|
||||
feature_vector[col_lower.index('lon')] = src_lon
|
||||
elif 'longitude' in col_lower:
|
||||
feature_vector[col_lower.index('longitude')] = src_lon
|
||||
else:
|
||||
if input_dim > 1:
|
||||
feature_vector[1] = src_lon
|
||||
except Exception:
|
||||
print(f"📋 Available columns: {col_lower[:10]}...") # Show first 10 columns
|
||||
|
||||
# Try to find and populate coordinate fields
|
||||
coord_mappings = [
|
||||
(('lat', 'latitude', 'src_lat', 'source_lat'), src_lat),
|
||||
(('lon', 'lng', 'longitude', 'src_lon', 'source_lon'), src_lon),
|
||||
(('dst_lat', 'dest_lat', 'destination_lat', 'end_lat'), dst_lat),
|
||||
(('dst_lon', 'dest_lon', 'destination_lon', 'end_lon', 'dst_lng'), dst_lon)
|
||||
]
|
||||
|
||||
for possible_names, value in coord_mappings:
|
||||
for name in possible_names:
|
||||
if name in col_lower:
|
||||
idx = col_lower.index(name)
|
||||
feature_vector[idx] = value
|
||||
print(f"✅ Mapped {name} (index {idx}) = {value}")
|
||||
break
|
||||
|
||||
# Calculate route features that might be useful
|
||||
route_distance = ((dst_lat - src_lat)**2 + (dst_lon - src_lon)**2)**0.5
|
||||
midpoint_lat = (src_lat + dst_lat) / 2
|
||||
midpoint_lon = (src_lon + dst_lon) / 2
|
||||
|
||||
# Try to populate additional features that might exist
|
||||
additional_features = {
|
||||
'distance': route_distance,
|
||||
'route_distance': route_distance,
|
||||
'midpoint_lat': midpoint_lat,
|
||||
'midpoint_lon': midpoint_lon,
|
||||
'lat_diff': abs(dst_lat - src_lat),
|
||||
'lon_diff': abs(dst_lon - src_lon)
|
||||
}
|
||||
|
||||
for feature_name, feature_value in additional_features.items():
|
||||
if feature_name in col_lower:
|
||||
idx = col_lower.index(feature_name)
|
||||
feature_vector[idx] = feature_value
|
||||
print(f"✅ Mapped {feature_name} (index {idx}) = {feature_value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error processing metadata: {e}")
|
||||
# Fallback to simple coordinate mapping
|
||||
feature_vector[:] = 0.0
|
||||
feature_vector[0] = src_lat
|
||||
if input_dim > 1:
|
||||
feature_vector[1] = src_lon
|
||||
if input_dim > 2:
|
||||
feature_vector[2] = dst_lat
|
||||
if input_dim > 3:
|
||||
feature_vector[3] = dst_lon
|
||||
else:
|
||||
print("⚠️ No preprocess_meta.npz found, using simple coordinate mapping")
|
||||
# Simple fallback mapping
|
||||
feature_vector[0] = src_lat
|
||||
if input_dim > 1:
|
||||
feature_vector[1] = src_lon
|
||||
if input_dim > 2:
|
||||
feature_vector[2] = dst_lat
|
||||
if input_dim > 3:
|
||||
feature_vector[3] = dst_lon
|
||||
|
||||
# Add some derived features to create more variation
|
||||
if input_dim > 4:
|
||||
feature_vector[4] = ((dst_lat - src_lat)**2 + (dst_lon - src_lon)**2)**0.5 # distance
|
||||
if input_dim > 5:
|
||||
feature_vector[5] = (src_lat + dst_lat) / 2 # midpoint lat
|
||||
if input_dim > 6:
|
||||
feature_vector[6] = (src_lon + dst_lon) / 2 # midpoint lon
|
||||
|
||||
# compute index using model
|
||||
try:
|
||||
print(f"🔍 Feature vector for prediction: {feature_vector[:8]}...") # Show first 8 values
|
||||
print(f"📍 Coordinates: src({src_lat}, {src_lon}) → dst({dst_lat}, {dst_lon})")
|
||||
index = compute_index(model, feature_vector)
|
||||
print(f"📊 Computed index: {index}")
|
||||
except Exception as e:
|
||||
return jsonify({"error": "compute_index failed", "detail": str(e)}), 500
|
||||
|
||||
|
||||
Reference in New Issue
Block a user