parameters fixed

This commit is contained in:
Pranav Malladi
2025-09-28 06:35:03 -04:00
parent b5efec5a60
commit 9771d56620
2 changed files with 31 additions and 37 deletions

View File

@@ -74,41 +74,38 @@ def health():
@app.route('/predict', methods=['POST', 'GET']) @app.route('/predict', methods=['POST', 'GET'])
def predict_endpoint(): def predict_endpoint():
"""Predict route between two points given source and destination with lat and lon. """Predict route between two points given source and destination with lat and lon."""
GET uses an example payload; POST accepts JSON with 'source' and 'destination'.
Both methods run the same model/index logic and return the same response format.
"""
example_payload = {
"source": {"lat": 38.9, "lon": -77.0},
"destination": {"lat": 38.95, "lon": -77.02}
}
info = "This endpoint expects a POST with JSON body (GET returns example payload)." # Default example values for GET
note = ( default_src_lat = 38.9
"Use POST to receive a prediction. Example: curl -X POST -H 'Content-Type: application/json' " default_src_lon = -77.0
"-d '{\"source\": {\"lat\": 38.9, \"lon\": -77.0}, \"destination\": {\"lat\": 38.95, \"lon\": -77.02}}' " default_dst_lat = 38.95
"http://127.0.0.1:5000/predict" default_dst_lon = -77.02
)
# unify request data: GET -> example, POST -> request.json
if request.method == 'GET': if request.method == 'GET':
data = example_payload # Read from query params or fall back to defaults
else: try:
data = request.json or {} src_lat = float(request.args.get('sourceLat', default_src_lat))
src_lon = float(request.args.get('sourceLon', default_src_lon))
dst_lat = float(request.args.get('destLat', default_dst_lat))
dst_lon = float(request.args.get('destLon', default_dst_lon))
except (TypeError, ValueError):
return jsonify({"error": "Invalid query parameters"}), 400
else: # POST
data = request.get_json(silent=True) or {}
source = data.get("source")
destination = data.get("destination")
source = data.get('source') if not source or not destination:
destination = data.get('destination') return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400
if not source or not destination:
return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400
try:
src_lat = float(source.get('lat'))
src_lon = float(source.get('lon'))
dst_lat = float(destination.get('lat'))
dst_lon = float(destination.get('lon'))
except (TypeError, ValueError):
return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400
try:
src_lat = float(source.get("lat"))
src_lon = float(source.get("lon"))
dst_lat = float(destination.get("lat"))
dst_lon = float(destination.get("lon"))
except (TypeError, ValueError):
return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400
# load model (loader infers architecture from checkpoint) # load model (loader infers architecture from checkpoint)
try: try:
model = load_model('model.pth', MLP) model = load_model('model.pth', MLP)
@@ -211,10 +208,10 @@ def predict_endpoint():
# compute index using model # compute index using model
try: try:
print(f"🔍 Feature vector for prediction: {feature_vector[:8]}...") # Show first 8 values print(f"Feature vector for prediction: {feature_vector[:8]}...") # Show first 8 values
print(f"📍 Coordinates: src({src_lat}, {src_lon}) → dst({dst_lat}, {dst_lon})") print(f"Coordinates: src({src_lat}, {src_lon}) → dst({dst_lat}, {dst_lon})")
index = compute_index(model, feature_vector) index = compute_index(model, feature_vector)
print(f"📊 Computed index: {index}") print(f"Computed index: {index}")
except Exception as e: except Exception as e:
return jsonify({"error": "compute_index failed", "detail": str(e)}), 500 return jsonify({"error": "compute_index failed", "detail": str(e)}), 500
@@ -222,10 +219,7 @@ def predict_endpoint():
"index": index, "index": index,
"prediction": {}, "prediction": {},
"called_with": request.method, "called_with": request.method,
"diagnostics": {"input_dim": int(input_dim)}, "diagnostics": {"input_dim": int(input_dim)}
"example": example_payload,
"info": info,
"note": note
} }
return jsonify(response_payload), 200 return jsonify(response_payload), 200

Binary file not shown.