flask update

This commit is contained in:
samarthjain2023
2025-09-27 19:37:04 -04:00
parent c939e0a2aa
commit 6bdd8f0fe3
2 changed files with 61 additions and 19 deletions

View File

@@ -43,28 +43,67 @@ def train_endpoint():
return jsonify({"status": "training_started"}) return jsonify({"status": "training_started"})
@app.route('/predict', methods=['POST']) @app.route('/predict', methods=['POST', 'GET'])
def predict_endpoint(): def predict_endpoint():
"""Predict single uploaded image. Expects form-data with file field named 'image'.""" """Predict route between two points given source and destination with lat and lon.
if 'image' not in request.files:
return jsonify({"error": "no image uploaded (field 'image')"}), 400 Expectation:
img = request.files['image'] - POST with JSON: {"source": {"lat": .., "lon": ..}, "destination": {"lat": .., "lon": ..}}
tmp_path = os.path.join(os.getcwd(), 'tmp_upload.jpg') - GET returns usage instructions for quick browser testing.
img.save(tmp_path) """
if request.method == 'GET':
return jsonify({
"info": "This endpoint expects a POST with JSON body.",
"example": {
"source": {"lat": 38.9, "lon": -77.0},
"destination": {"lat": 38.95, "lon": -77.02}
},
"note": "Use POST to receive a prediction. Example: curl -X POST -H 'Content-Type: application/json' -d '{\"source\": {\"lat\": 38.9, \"lon\": -77.0}, \"destination\": {\"lat\": 38.95, \"lon\": -77.02}}' http://127.0.0.1:5000/predict"
}), 200
data = request.json or {}
source = data.get('source')
destination = data.get('destination')
if not source or not destination:
return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400
try:
src_lat = float(source.get('lat'))
src_lon = float(source.get('lon'))
dst_lat = float(destination.get('lat'))
dst_lon = float(destination.get('lon'))
except (TypeError, ValueError):
return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400
# Ensure compute_reroute exists and is callable
try:
from openweather_client import compute_reroute
except Exception as e:
return jsonify({
"error": "compute_reroute not found in openweather_client",
"detail": str(e),
"hint": "Provide openweather_client.compute_reroute or implement a callable that accepts (src_lat, src_lon, dst_lat, dst_lon)"
}), 500
if not callable(compute_reroute):
return jsonify({"error": "openweather_client.compute_reroute is not callable"}), 500
# Call compute_reroute with fallback strategies
try: try:
from inference import load_model, predict_image
model_path = os.path.join(os.getcwd(), 'model.pth')
if not os.path.exists(model_path):
return jsonify({"error": "no trained model found (run /train first)"}), 400
model, idx_to_class = load_model(model_path)
idx, conf = predict_image(model, tmp_path)
label = idx_to_class.get(idx) if idx_to_class else str(idx)
return jsonify({"label": label, "confidence": conf})
finally:
try: try:
os.remove(tmp_path) result = compute_reroute(src_lat, src_lon, dst_lat, dst_lon)
except Exception: except TypeError:
pass # fallback: single payload dict
payload = {'source': {'lat': src_lat, 'lon': src_lon}, 'destination': {'lat': dst_lat, 'lon': dst_lon}}
result = compute_reroute(payload)
# Normalize response
if isinstance(result, dict):
return jsonify(result)
else:
return jsonify({"result": result})
except Exception as e:
return jsonify({"error": "compute_reroute invocation failed", "detail": str(e)}), 500
@app.route('/') @app.route('/')
def home(): def home():
@@ -136,3 +175,6 @@ if __name__ == '__main__':
# ai_response = f"AI received: {user_input}" # ai_response = f"AI received: {user_input}"
# return jsonify({"response": ai_response}) # return jsonify({"response": ai_response})
# ai_response = f"AI received: {user_input}"
# return jsonify({"response": ai_response"})