parameters fixed

This commit is contained in:
Pranav Malladi
2025-09-28 06:35:03 -04:00
parent b5efec5a60
commit 9771d56620
2 changed files with 31 additions and 37 deletions

View File

@@ -74,41 +74,38 @@ def health():
@app.route('/predict', methods=['POST', 'GET'])
def predict_endpoint():
"""Predict route between two points given source and destination with lat and lon.
GET uses an example payload; POST accepts JSON with 'source' and 'destination'.
Both methods run the same model/index logic and return the same response format.
"""
example_payload = {
"source": {"lat": 38.9, "lon": -77.0},
"destination": {"lat": 38.95, "lon": -77.02}
}
"""Predict route between two points given source and destination with lat and lon."""
info = "This endpoint expects a POST with JSON body (GET returns example payload)."
note = (
"Use POST to receive a prediction. Example: curl -X POST -H 'Content-Type: application/json' "
"-d '{\"source\": {\"lat\": 38.9, \"lon\": -77.0}, \"destination\": {\"lat\": 38.95, \"lon\": -77.02}}' "
"http://127.0.0.1:5000/predict"
)
# Default example values for GET
default_src_lat = 38.9
default_src_lon = -77.0
default_dst_lat = 38.95
default_dst_lon = -77.02
# unify request data: GET -> example, POST -> request.json
if request.method == 'GET':
data = example_payload
else:
data = request.json or {}
# Read from query params or fall back to defaults
try:
src_lat = float(request.args.get('sourceLat', default_src_lat))
src_lon = float(request.args.get('sourceLon', default_src_lon))
dst_lat = float(request.args.get('destLat', default_dst_lat))
dst_lon = float(request.args.get('destLon', default_dst_lon))
except (TypeError, ValueError):
return jsonify({"error": "Invalid query parameters"}), 400
else: # POST
data = request.get_json(silent=True) or {}
source = data.get("source")
destination = data.get("destination")
source = data.get('source')
destination = data.get('destination')
if not source or not destination:
return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400
try:
src_lat = float(source.get('lat'))
src_lon = float(source.get('lon'))
dst_lat = float(destination.get('lat'))
dst_lon = float(destination.get('lon'))
except (TypeError, ValueError):
return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400
if not source or not destination:
return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400
try:
src_lat = float(source.get("lat"))
src_lon = float(source.get("lon"))
dst_lat = float(destination.get("lat"))
dst_lon = float(destination.get("lon"))
except (TypeError, ValueError):
return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400
# load model (loader infers architecture from checkpoint)
try:
model = load_model('model.pth', MLP)
@@ -211,10 +208,10 @@ def predict_endpoint():
# compute index using model
try:
print(f"🔍 Feature vector for prediction: {feature_vector[:8]}...") # Show first 8 values
print(f"📍 Coordinates: src({src_lat}, {src_lon}) → dst({dst_lat}, {dst_lon})")
print(f"Feature vector for prediction: {feature_vector[:8]}...") # Show first 8 values
print(f"Coordinates: src({src_lat}, {src_lon}) → dst({dst_lat}, {dst_lon})")
index = compute_index(model, feature_vector)
print(f"📊 Computed index: {index}")
print(f"Computed index: {index}")
except Exception as e:
return jsonify({"error": "compute_index failed", "detail": str(e)}), 500
@@ -222,10 +219,7 @@ def predict_endpoint():
"index": index,
"prediction": {},
"called_with": request.method,
"diagnostics": {"input_dim": int(input_dim)},
"example": example_payload,
"info": info,
"note": note
"diagnostics": {"input_dim": int(input_dim)}
}
return jsonify(response_payload), 200

Binary file not shown.