diff --git a/roadcast/app.py b/roadcast/app.py index a2631f1..86ee1d9 100644 --- a/roadcast/app.py +++ b/roadcast/app.py @@ -74,41 +74,38 @@ def health(): @app.route('/predict', methods=['POST', 'GET']) def predict_endpoint(): - """Predict route between two points given source and destination with lat and lon. - GET uses an example payload; POST accepts JSON with 'source' and 'destination'. - Both methods run the same model/index logic and return the same response format. - """ - example_payload = { - "source": {"lat": 38.9, "lon": -77.0}, - "destination": {"lat": 38.95, "lon": -77.02} - } + """Predict route between two points given source and destination with lat and lon.""" - info = "This endpoint expects a POST with JSON body (GET returns example payload)." - note = ( - "Use POST to receive a prediction. Example: curl -X POST -H 'Content-Type: application/json' " - "-d '{\"source\": {\"lat\": 38.9, \"lon\": -77.0}, \"destination\": {\"lat\": 38.95, \"lon\": -77.02}}' " - "http://127.0.0.1:5000/predict" - ) + # Default example values for GET + default_src_lat = 38.9 + default_src_lon = -77.0 + default_dst_lat = 38.95 + default_dst_lon = -77.02 - # unify request data: GET -> example, POST -> request.json if request.method == 'GET': - data = example_payload - else: - data = request.json or {} + # Read from query params or fall back to defaults + try: + src_lat = float(request.args.get('sourceLat', default_src_lat)) + src_lon = float(request.args.get('sourceLon', default_src_lon)) + dst_lat = float(request.args.get('destLat', default_dst_lat)) + dst_lon = float(request.args.get('destLon', default_dst_lon)) + except (TypeError, ValueError): + return jsonify({"error": "Invalid query parameters"}), 400 + else: # POST + data = request.get_json(silent=True) or {} + source = data.get("source") + destination = data.get("destination") - source = data.get('source') - destination = data.get('destination') - if not source or not destination: - return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400 - - try: - src_lat = float(source.get('lat')) - src_lon = float(source.get('lon')) - dst_lat = float(destination.get('lat')) - dst_lon = float(destination.get('lon')) - except (TypeError, ValueError): - return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400 + if not source or not destination: + return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400 + try: + src_lat = float(source.get("lat")) + src_lon = float(source.get("lon")) + dst_lat = float(destination.get("lat")) + dst_lon = float(destination.get("lon")) + except (TypeError, ValueError): + return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400 # load model (loader infers architecture from checkpoint) try: model = load_model('model.pth', MLP) @@ -211,10 +208,10 @@ def predict_endpoint(): # compute index using model try: - print(f"🔍 Feature vector for prediction: {feature_vector[:8]}...") # Show first 8 values - print(f"📍 Coordinates: src({src_lat}, {src_lon}) → dst({dst_lat}, {dst_lon})") + print(f"Feature vector for prediction: {feature_vector[:8]}...") # Show first 8 values + print(f"Coordinates: src({src_lat}, {src_lon}) → dst({dst_lat}, {dst_lon})") index = compute_index(model, feature_vector) - print(f"📊 Computed index: {index}") + print(f"Computed index: {index}") except Exception as e: return jsonify({"error": "compute_index failed", "detail": str(e)}), 500 @@ -222,10 +219,7 @@ def predict_endpoint(): "index": index, "prediction": {}, "called_with": request.method, - "diagnostics": {"input_dim": int(input_dim)}, - "example": example_payload, - "info": info, - "note": note + "diagnostics": {"input_dim": int(input_dim)} } return jsonify(response_payload), 200 diff --git a/roadcast/model.pth b/roadcast/model.pth index d524eb8..9f2777c 100644 Binary files a/roadcast/model.pth and b/roadcast/model.pth differ