fixed flask run
This commit is contained in:
122
roadcast/app.py
122
roadcast/app.py
@@ -1,8 +1,78 @@
|
|||||||
|
from flask import Flask, request, jsonify
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from openmeteo_client import compute_index
|
||||||
|
|
||||||
|
# Load environment variables from .env file
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import json
|
||||||
|
|
||||||
|
# ML imports are lazy to avoid heavy imports on simple runs
|
||||||
|
|
||||||
|
@app.route('/')
|
||||||
|
def home():
|
||||||
|
return "<h1>Welcome to the Flask App</h1><p>Try /get-data or /health endpoints.</p>"
|
||||||
|
|
||||||
|
@app.route('/get-data', methods=['GET'])
|
||||||
|
def get_data():
|
||||||
|
# Example GET request handler
|
||||||
|
data = {"message": "Hello from Flask!"}
|
||||||
|
return jsonify(data)
|
||||||
|
|
||||||
|
@app.route('/post-data', methods=['POST'])
|
||||||
|
def post_data():
|
||||||
|
# Example POST request handler
|
||||||
|
content = request.json
|
||||||
|
# Process content or call AI model here
|
||||||
|
response = {"you_sent": content}
|
||||||
|
return jsonify(response)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/train', methods=['POST'])
|
||||||
|
def train_endpoint():
|
||||||
|
"""Trigger training. Expects JSON: {"data_root": "path/to/data", "epochs": 3}
|
||||||
|
Training runs in a background thread and saves model to model.pth in repo root.
|
||||||
|
"""
|
||||||
|
payload = request.json or {}
|
||||||
|
data_root = payload.get('data_root')
|
||||||
|
epochs = int(payload.get('epochs', 3))
|
||||||
|
if not data_root or not os.path.isdir(data_root):
|
||||||
|
return jsonify({"error": "data_root must be a valid directory path"}), 400
|
||||||
|
|
||||||
|
def _run_training():
|
||||||
|
from train import train
|
||||||
|
train(data_root, epochs=epochs)
|
||||||
|
|
||||||
|
t = threading.Thread(target=_run_training, daemon=True)
|
||||||
|
t.start()
|
||||||
|
return jsonify({"status": "training_started"})
|
||||||
|
|
||||||
|
@app.route('/health', methods=['GET'])
|
||||||
|
def health():
|
||||||
|
"""Return status of loaded ML artifacts (model, centers, preprocess_meta)."""
|
||||||
|
try:
|
||||||
|
from openmeteo_inference import init_inference
|
||||||
|
status = init_inference()
|
||||||
|
return jsonify({'ok': True, 'artifacts': status})
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({'ok': False, 'error': str(e)}), 500
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# eager load model/artifacts at startup (best-effort)
|
||||||
|
try:
|
||||||
|
from openmeteo_inference import init_inference
|
||||||
|
init_inference()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
app.run(debug=True)
|
||||||
|
|
||||||
@app.route('/predict', methods=['POST', 'GET'])
|
@app.route('/predict', methods=['POST', 'GET'])
|
||||||
def predict_endpoint():
|
def predict_endpoint():
|
||||||
"""Predict route between two points given source and destination with lat and lon.
|
"""Predict route between two points given source and destination with lat and lon.
|
||||||
|
|
||||||
|
|
||||||
Expectation:
|
Expectation:
|
||||||
- POST with JSON: {"source": {"lat": .., "lon": ..}, "destination": {"lat": .., "lon": ..}}
|
- POST with JSON: {"source": {"lat": .., "lon": ..}, "destination": {"lat": .., "lon": ..}}
|
||||||
- GET returns usage instructions for quick browser testing.
|
- GET returns usage instructions for quick browser testing.
|
||||||
@@ -18,18 +88,46 @@ def predict_endpoint():
|
|||||||
"http://127.0.0.1:5000/predict"
|
"http://127.0.0.1:5000/predict"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if request.method == 'GET':
|
if request.method == 'GET':
|
||||||
return jsonify({"info": info, "example": example_payload, "note": note}), 200
|
# Return the same structure as POST but without prediction
|
||||||
|
# response_payload = {
|
||||||
|
# "index": None,
|
||||||
|
# "prediction": {},
|
||||||
|
# "called_with": "GET",
|
||||||
|
# "diagnostics": {},
|
||||||
|
# "example": example_payload,
|
||||||
|
# "info": info,
|
||||||
|
# "note": note
|
||||||
|
# }
|
||||||
|
|
||||||
|
# For GET request, compute the road risk index using the example coordinates
|
||||||
|
src_lat = example_payload['source']['lat']
|
||||||
|
src_lon = example_payload['source']['lon']
|
||||||
|
dst_lat = example_payload['destination']['lat']
|
||||||
|
dst_lon = example_payload['destination']['lon']
|
||||||
|
|
||||||
|
# Use the compute_index function to get the road risk index
|
||||||
|
index = compute_index(src_lat, src_lon)
|
||||||
|
|
||||||
|
# Prepare the response payload
|
||||||
|
response_payload = {
|
||||||
|
"index": index, # The computed index here
|
||||||
|
"prediction": {},
|
||||||
|
"called_with": "GET",
|
||||||
|
"diagnostics": {},
|
||||||
|
"example": example_payload,
|
||||||
|
"info": info,
|
||||||
|
"note": note
|
||||||
|
}
|
||||||
|
return jsonify(response_payload), 200
|
||||||
|
|
||||||
|
# POST request logic
|
||||||
data = request.json or {}
|
data = request.json or {}
|
||||||
source = data.get('source')
|
source = data.get('source')
|
||||||
destination = data.get('destination')
|
destination = data.get('destination')
|
||||||
if not source or not destination:
|
if not source or not destination:
|
||||||
return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400
|
return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
src_lat = float(source.get('lat'))
|
src_lat = float(source.get('lat'))
|
||||||
src_lon = float(source.get('lon'))
|
src_lon = float(source.get('lon'))
|
||||||
@@ -38,7 +136,6 @@ def predict_endpoint():
|
|||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400
|
return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400
|
||||||
|
|
||||||
|
|
||||||
# Ensure compute_reroute exists and is callable
|
# Ensure compute_reroute exists and is callable
|
||||||
try:
|
try:
|
||||||
from openmeteo_client import compute_reroute
|
from openmeteo_client import compute_reroute
|
||||||
@@ -50,11 +147,9 @@ def predict_endpoint():
|
|||||||
"(Open-Meteo does not need an API key)"
|
"(Open-Meteo does not need an API key)"
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
if not callable(compute_reroute):
|
if not callable(compute_reroute):
|
||||||
return jsonify({"error": "openmeteo_client.compute_reroute is not callable"}), 500
|
return jsonify({"error": "openmeteo_client.compute_reroute is not callable"}), 500
|
||||||
|
|
||||||
|
|
||||||
def _extract_index(res):
|
def _extract_index(res):
|
||||||
if res is None:
|
if res is None:
|
||||||
return None
|
return None
|
||||||
@@ -69,20 +164,17 @@ def predict_endpoint():
|
|||||||
return res[k]
|
return res[k]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# Call compute_reroute (Open-Meteo requires no API key)
|
# Call compute_reroute (Open-Meteo requires no API key)
|
||||||
try:
|
try:
|
||||||
result = compute_reroute(src_lat, src_lon, dst_lat, dst_lon)
|
result = compute_reroute(src_lat, src_lon, dst_lat, dst_lon)
|
||||||
called_with = "positional"
|
called_with = "positional"
|
||||||
|
|
||||||
|
|
||||||
diagnostics = {"type": type(result).__name__}
|
diagnostics = {"type": type(result).__name__}
|
||||||
try:
|
try:
|
||||||
diagnostics["repr"] = repr(result)[:1000]
|
diagnostics["repr"] = repr(result)[:1000]
|
||||||
except Exception:
|
except Exception:
|
||||||
diagnostics["repr"] = "<unrepr-able>"
|
diagnostics["repr"] = "<unrepr-able>"
|
||||||
|
|
||||||
|
|
||||||
# Normalize return types
|
# Normalize return types
|
||||||
if isinstance(result, (list, tuple)):
|
if isinstance(result, (list, tuple)):
|
||||||
idx = None
|
idx = None
|
||||||
@@ -102,7 +194,6 @@ def predict_endpoint():
|
|||||||
index = None
|
index = None
|
||||||
prediction = {"value": result}
|
prediction = {"value": result}
|
||||||
|
|
||||||
|
|
||||||
response_payload = {
|
response_payload = {
|
||||||
"index": index,
|
"index": index,
|
||||||
"prediction": prediction,
|
"prediction": prediction,
|
||||||
@@ -113,7 +204,6 @@ def predict_endpoint():
|
|||||||
"note": note
|
"note": note
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# Add warning if no routing/index info found
|
# Add warning if no routing/index info found
|
||||||
expected_keys = ('route', 'path', 'distance', 'directions', 'index', 'idx', 'cluster')
|
expected_keys = ('route', 'path', 'distance', 'directions', 'index', 'idx', 'cluster')
|
||||||
if (not isinstance(prediction, dict) or not any(k in prediction for k in expected_keys)) and index is None:
|
if (not isinstance(prediction, dict) or not any(k in prediction for k in expected_keys)) and index is None:
|
||||||
@@ -122,9 +212,13 @@ def predict_endpoint():
|
|||||||
"See diagnostics for details."
|
"See diagnostics for details."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return jsonify(response_payload), 200
|
||||||
|
|
||||||
return jsonify(response_payload)
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
"error": "Error processing the request",
|
||||||
|
"detail": str(e)
|
||||||
|
}), 500
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"error": "compute_reroute invocation failed", "detail": str(e)}), 500
|
return jsonify({"error": "compute_reroute invocation failed", "detail": str(e)}), 500
|
||||||
@@ -294,46 +294,41 @@ def compute_reroute(
|
|||||||
"grid_shape": (n_lat, n_lon)
|
"grid_shape": (n_lat, n_lon)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def compute_index(lat: float,
|
||||||
|
lon: float,
|
||||||
|
max_risk: float = 10.0,
|
||||||
|
num_bins: int = 10,
|
||||||
|
risk_provider: Optional[Callable[[float, float], float]] = None) -> int:
|
||||||
|
"""
|
||||||
|
Computes and returns an index based on road risk and accident information.
|
||||||
|
|
||||||
def compute_index_and_reroute(lat: float,
|
Args:
|
||||||
lon: float,
|
lat: Latitude of the location.
|
||||||
api_key: Optional[str] = None,
|
lon: Longitude of the location.
|
||||||
roadrisk_url: Optional[str] = None,
|
max_risk: Maximum possible road risk score.
|
||||||
max_risk: float = 10.0,
|
num_bins: Number of bins to divide the risk range into.
|
||||||
num_bins: int = 10,
|
reroute_kwargs: Optional dictionary passed to reroute logic.
|
||||||
reroute_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
risk_provider: Optional custom risk provider function.
|
||||||
"""
|
|
||||||
Get road risk, map to index (1..num_bins), and attempt reroute.
|
|
||||||
reroute_kwargs forwarded to compute_reroute.
|
|
||||||
api_key/roadrisk_url accepted for compatibility but ignored by Open-Meteo implementation.
|
|
||||||
"""
|
|
||||||
if reroute_kwargs is None:
|
|
||||||
reroute_kwargs = {}
|
|
||||||
|
|
||||||
data, features = fetch_road_risk(lat, lon, extra_params=reroute_kwargs, api_key=api_key, roadrisk_url=roadrisk_url)
|
Returns:
|
||||||
road_risk = float(features.get("road_risk_score", 0.0))
|
An integer index (1 to num_bins) based on computed road risk.
|
||||||
|
"""
|
||||||
|
# If a risk provider is not provided, use a default one
|
||||||
|
if risk_provider is None:
|
||||||
|
risk_provider = lambda lat, lon: get_risk_score(lat, lon)
|
||||||
|
|
||||||
accidents = features.get("accidents") or features.get("accident_count")
|
# Fetch road risk score using the provided risk provider
|
||||||
try:
|
road_risk = float(risk_provider(lat, lon))
|
||||||
if accidents is not None:
|
|
||||||
# fallback: map accident count to index if present
|
|
||||||
from .models import accidents_to_bucket # may not exist; wrapped in try
|
|
||||||
idx = accidents_to_bucket(int(accidents), max_count=20000, num_bins=num_bins)
|
|
||||||
else:
|
|
||||||
idx = risk_to_index(road_risk, max_risk=max_risk, num_bins=num_bins)
|
|
||||||
except Exception:
|
|
||||||
idx = risk_to_index(road_risk, max_risk=max_risk, num_bins=num_bins)
|
|
||||||
|
|
||||||
def _rp(lat_, lon_):
|
# Compute the index based on road risk score and the max_risk, num_bins parameters
|
||||||
return get_risk_score(lat_, lon_, api_key=api_key, roadrisk_url=roadrisk_url)
|
# The formula will divide the risk score into `num_bins` bins
|
||||||
|
# The index will be a number between 1 and num_bins based on the risk score
|
||||||
|
|
||||||
reroute_info = compute_reroute(lat, lon, risk_provider=_rp, **reroute_kwargs)
|
# Normalize the risk score to be between 0 and max_risk
|
||||||
return {
|
normalized_risk = min(road_risk, max_risk) / max_risk
|
||||||
"lat": lat,
|
|
||||||
"lon": lon,
|
# Compute the index based on the normalized risk score
|
||||||
"index": int(idx),
|
index = int(normalized_risk * num_bins)
|
||||||
"road_risk_score": road_risk,
|
|
||||||
"features": features,
|
# Ensure the index is within the expected range of 1 to num_bins
|
||||||
"reroute": reroute_info,
|
return max(1, min(index + 1, num_bins)) # Adding 1 because index is 0-based
|
||||||
"raw_roadrisk_response": data,
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user