post request works
roadcast/app.py (240 changed lines)
@@ -1,75 +1,35 @@
from flask import Flask, request, jsonify
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

app = Flask(__name__)
import os
import threading
import json

# ML imports are lazy to avoid heavy imports on simple runs


@app.route('/get-data', methods=['GET'])
def get_data():
    # Example GET request handler
    data = {"message": "Hello from Flask!"}
    return jsonify(data)

@app.route('/post-data', methods=['POST'])
def post_data():
    # Example POST request handler
    content = request.json
    # Process content or call AI model here
    response = {"you_sent": content}
    return jsonify(response)


@app.route('/train', methods=['POST'])
def train_endpoint():
    """Trigger training. Expects JSON: {"data_root": "path/to/data", "epochs": 3}
    Training runs in a background thread and saves the model to model.pth in the repo root.
    """
    payload = request.json or {}
    data_root = payload.get('data_root')
    epochs = int(payload.get('epochs', 3))
    if not data_root or not os.path.isdir(data_root):
        return jsonify({"error": "data_root must be a valid directory path"}), 400

    def _run_training():
        from train import train
        train(data_root, epochs=epochs)

    t = threading.Thread(target=_run_training, daemon=True)
    t.start()
    return jsonify({"status": "training_started"})

@app.route('/predict', methods=['POST', 'GET'])
def predict_endpoint():
    """Predict route between two points given source and destination with lat and lon.

    Expectation:
    - POST with JSON: {"source": {"lat": .., "lon": ..}, "destination": {"lat": .., "lon": ..}}
    - GET returns usage instructions for quick browser testing.
    """
+   example_payload = {
+       "source": {"lat": 38.9, "lon": -77.0},
+       "destination": {"lat": 38.95, "lon": -77.02}
+   }
+   info = "This endpoint expects a POST with JSON body."
+   note = (
+       "Use POST to receive a prediction. Example: curl -X POST -H 'Content-Type: application/json' "
+       "-d '{\"source\": {\"lat\": 38.9, \"lon\": -77.0}, \"destination\": {\"lat\": 38.95, \"lon\": -77.02}}' "
+       "http://127.0.0.1:5000/predict"
+   )

    if request.method == 'GET':
-       return jsonify({
-           "info": "This endpoint expects a POST with JSON body.",
-           "example": {
-               "source": {"lat": 38.9, "lon": -77.0},
-               "destination": {"lat": 38.95, "lon": -77.02}
-           },
-           "note": "Use POST to receive a prediction. Example: curl -X POST -H 'Content-Type: application/json' -d '{\"source\": {\"lat\": 38.9, \"lon\": -77.0}, \"destination\": {\"lat\": 38.95, \"lon\": -77.02}}' http://127.0.0.1:5000/predict"
-       }), 200
+       return jsonify({"info": info, "example": example_payload, "note": note}), 200

    data = request.json or {}
    source = data.get('source')
    destination = data.get('destination')
    if not source or not destination:
        return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400

    try:
        src_lat = float(source.get('lat'))
        src_lon = float(source.get('lon'))
@@ -78,107 +38,93 @@ def predict_endpoint():
    except (TypeError, ValueError):
        return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400

    # Ensure compute_reroute exists and is callable
    try:
-       from openweather_client import compute_reroute
+       from openmeteo_client import compute_reroute
    except Exception as e:
        return jsonify({
-           "error": "compute_reroute not found in openweather_client",
+           "error": "compute_reroute not found in openmeteo_client",
            "detail": str(e),
-           "hint": "Provide openweather_client.compute_reroute or implement a callable that accepts (src_lat, src_lon, dst_lat, dst_lon)"
+           "hint": "Provide openmeteo_client.compute_reroute "
+                   "(Open-Meteo does not need an API key)"
        }), 500

    if not callable(compute_reroute):
-       return jsonify({"error": "openweather_client.compute_reroute is not callable"}), 500
+       return jsonify({"error": "openmeteo_client.compute_reroute is not callable"}), 500

-   # Call compute_reroute with fallback strategies
+   def _extract_index(res):
+       if res is None:
+           return None
+       if isinstance(res, (int, float)):
+           return int(res)
+       if isinstance(res, dict):
+           for k in ('index', 'idx', 'cluster', 'cluster_idx', 'label_index', 'label_idx'):
+               if k in res:
+                   try:
+                       return int(res[k])
+                   except Exception:
+                       return res[k]
+       return None

+   # Call compute_reroute (Open-Meteo requires no API key)
    try:
+       result = compute_reroute(src_lat, src_lon, dst_lat, dst_lon)
+       called_with = "positional"
+
+       diagnostics = {"type": type(result).__name__}
        try:
-           result = compute_reroute(src_lat, src_lon, dst_lat, dst_lon)
-       except TypeError:
-           # fallback: single payload dict
-           payload = {'source': {'lat': src_lat, 'lon': src_lon}, 'destination': {'lat': dst_lat, 'lon': dst_lon}}
-           result = compute_reroute(payload)
-
-       # Normalize response
-       if isinstance(result, dict):
-           return jsonify(result)
+           diagnostics["repr"] = repr(result)[:1000]
+       except Exception:
+           diagnostics["repr"] = "<unrepr-able>"
+
+       # Normalize return types
+       if isinstance(result, (list, tuple)):
+           idx = None
+           for el in result:
+               idx = _extract_index(el)
+               if idx is not None:
+                   break
+           prediction = {"items": list(result)}
+           index = idx
+       elif isinstance(result, dict):
+           index = _extract_index(result)
+           prediction = result
+       elif isinstance(result, (int, float, str)):
+           index = _extract_index(result)
+           prediction = {"value": result}
        else:
-           return jsonify({"result": result})
+           index = None
+           prediction = {"value": result}
+
+       response_payload = {
+           "index": index,
+           "prediction": prediction,
+           "called_with": called_with,
+           "diagnostics": diagnostics,
+           "example": example_payload,
+           "info": info,
+           "note": note
+       }
+
+       # Add warning if no routing/index info found
+       expected_keys = ('route', 'path', 'distance', 'directions', 'index', 'idx', 'cluster')
+       if (not isinstance(prediction, dict) or not any(k in prediction for k in expected_keys)) and index is None:
+           response_payload["warning"] = (
+               "No routing/index information returned from compute_reroute. "
+               "See diagnostics for details."
+           )
+
+       return jsonify(response_payload)

    except Exception as e:
        return jsonify({"error": "compute_reroute invocation failed", "detail": str(e)}), 500

@app.route('/')
def home():
    return "<h1>Welcome to the Flask App</h1><p>Try /get-data or /health endpoints.</p>"

@app.route('/predict-roadrisk', methods=['POST'])
def predict_roadrisk():
    """Proxy endpoint to predict a roadrisk cluster from lat/lon/datetime.

    Expects JSON body with: {"lat": 38.9, "lon": -77.0, "datetime": "2025-09-27T12:00:00", "roadrisk_url": "https://..."}
    If roadrisk_url is not provided the endpoint will call OpenWeather OneCall (requires an API key via the OPENWEATHER_API_KEY env var).
    """
    payload = request.json or {}
    lat = payload.get('lat')
    lon = payload.get('lon')
    dt = payload.get('datetime')
    street = payload.get('street', '')
    roadrisk_url = payload.get('roadrisk_url')
    # prefer explicit api_key in request, otherwise read from OPENWEATHER_API_KEY env var
    api_key = payload.get('api_key') or os.environ.get('OPENWEATHER_API_KEY')

    if lat is None or lon is None:
        return jsonify({"error": "lat and lon are required fields"}), 400

    try:
        from openweather_inference import predict_from_openweather
        # pass api_key (may be None) to the inference helper; helper will raise if a key is required
        res = predict_from_openweather(
            lat, lon,
            dt_iso=dt,
            street=street,
            api_key=api_key,
            train_csv=os.path.join(os.getcwd(), 'data.csv'),
            preprocess_meta=None,
            model_path=os.path.join(os.getcwd(), 'model.pth'),
            centers_path=os.path.join(os.getcwd(), 'kmeans_centers_all.npz'),
            roadrisk_url=roadrisk_url
        )
        return jsonify(res)
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route('/health', methods=['GET'])
def health():
    """Return status of loaded ML artifacts (model, centers, preprocess_meta)."""
    try:
        from openweather_inference import init_inference
        status = init_inference()
        return jsonify({'ok': True, 'artifacts': status})
    except Exception as e:
        return jsonify({'ok': False, 'error': str(e)}), 500

if __name__ == '__main__':
    # eager load model/artifacts at startup (best-effort)
    try:
        from openweather_inference import init_inference
        init_inference()
    except Exception:
        pass
    app.run(debug=True)

# @app.route('/post-data', methods=['POST'])
# def post_data():
#     content = request.json
#     user_input = content.get('input')

#     # Example: Simple echo AI (replace with real AI model code)
#     ai_response = f"AI received: {user_input}"

#     return jsonify({"response": ai_response})
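For quick manual testing of the endpoints above, a minimal client sketch, assuming the dev server from app.py is already running on http://127.0.0.1:5000 and using the payload shapes documented in the docstrings; the file name and the data path are placeholders, not part of the commit:

# smoke_test.py -- hypothetical helper for manual testing.
import requests

BASE = "http://127.0.0.1:5000"

# POST /predict with the documented source/destination shape
payload = {
    "source": {"lat": 38.9, "lon": -77.0},
    "destination": {"lat": 38.95, "lon": -77.02},
}
resp = requests.post(f"{BASE}/predict", json=payload, timeout=120)
print(resp.status_code, resp.json().get("index"))

# POST /train; 'path/to/data' is the docstring's placeholder, so expect a
# 400 unless that directory actually exists on the server
resp = requests.post(f"{BASE}/train", json={"data_root": "path/to/data", "epochs": 3}, timeout=10)
print(resp.status_code, resp.json())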
roadcast/openmeteo_client.py (new file, 339 lines)
@@ -0,0 +1,339 @@
"""Open-Meteo historical weather client + simple road-risk heuristics.

Backwards-compatible API:
- fetch_weather(lat, lon, params=None, api_key=None)
- fetch_road_risk(lat, lon, extra_params=None, api_key=None, roadrisk_url=None)
- get_risk_score(lat, lon, **fetch_kwargs)
- compute_reroute(...)
- compute_index_and_reroute(...)
"""
import os
from typing import Tuple, Dict, Any, Optional, Callable, List
import requests
import heapq
import math
from datetime import date, timedelta

# Open-Meteo archive endpoint (no API key required)
BASE_ARCHIVE_URL = "https://archive-api.open-meteo.com/v1/archive"


def fetch_weather(lat: float, lon: float, params: Optional[dict] = None, api_key: Optional[str] = None) -> dict:
    """Fetch historical weather from Open-Meteo archive API.

    Params may include 'start_date', 'end_date' (YYYY-MM-DD) and 'hourly' (comma-separated vars).
    Defaults to yesterday..today and hourly variables useful for road risk.
    (api_key parameter is accepted for compatibility but ignored.)
    """
    if params is None:
        params = {}

    today = date.today()
    start = params.get("start_date", (today - timedelta(days=1)).isoformat())
    end = params.get("end_date", today.isoformat())

    hourly = params.get(
        "hourly",
        ",".join(["temperature_2m", "relativehumidity_2m", "windspeed_10m", "precipitation", "weathercode"])
    )

    query = {
        "latitude": lat,
        "longitude": lon,
        "start_date": start,
        "end_date": end,
        "hourly": hourly,
        "timezone": params.get("timezone", "UTC"),
    }

    resp = requests.get(BASE_ARCHIVE_URL, params=query, timeout=15)
    resp.raise_for_status()
    return resp.json()


def fetch_road_risk(
    lat: float,
    lon: float,
    extra_params: Optional[dict] = None,
    api_key: Optional[str] = None,
    roadrisk_url: Optional[str] = None
) -> Tuple[dict, Dict[str, Any]]:
    """
    Compute a simple road risk estimation using Open-Meteo historical weather.

    Returns (raw_data, features) where features includes 'road_risk_score' (float).
    api_key and roadrisk_url are accepted for backward compatibility but ignored.
    """
    params = {}
    if extra_params:
        params.update(extra_params)

    # fetch weather via Open-Meteo archive
    try:
        data = fetch_weather(lat, lon, params=params)
    except Exception as e:
        features: Dict[str, Any] = {"road_risk_score": 0.0, "error": str(e)}
        return {}, features

    hourly = data.get("hourly", {}) if isinstance(data, dict) else {}

    def _arr_mean(key):
        arr = hourly.get(key)
        if isinstance(arr, list) and arr:
            valid = [float(x) for x in arr if x is not None]
            return sum(valid) / max(1, len(valid)) if valid else None
        return None

    def _arr_max(key):
        arr = hourly.get(key)
        if isinstance(arr, list) and arr:
            valid = [float(x) for x in arr if x is not None]
            return max(valid) if valid else None
        return None

    precip_mean = _arr_mean("precipitation")
    wind_mean = _arr_mean("windspeed_10m")
    wind_max = _arr_max("windspeed_10m")
    temp_mean = _arr_mean("temperature_2m")
    humidity_mean = _arr_mean("relativehumidity_2m")
    weathercodes = hourly.get("weathercode", [])

    # heuristic risk scoring:
    risk = 0.0
    if precip_mean is not None:
        risk += float(precip_mean) * 2.0
    if wind_mean is not None:
        risk += float(wind_mean) * 0.1
    if wind_max is not None and float(wind_max) > 15.0:
        risk += 1.0
    if humidity_mean is not None and float(humidity_mean) > 85.0:
        risk += 0.5
    try:
        # sample Open-Meteo weather codes that indicate precipitation/snow
        if any(int(wc) in (51, 61, 63, 65, 80, 81, 82, 71, 73, 75, 85, 86) for wc in weathercodes if wc is not None):
            risk += 1.0
    except Exception:
        pass

    if not math.isfinite(risk):
        risk = 0.0
    if risk < 0:
        risk = 0.0

    features: Dict[str, Any] = {
        "precipitation_mean": float(precip_mean) if precip_mean is not None else None,
        "wind_mean": float(wind_mean) if wind_mean is not None else None,
        "wind_max": float(wind_max) if wind_max is not None else None,
        "temp_mean": float(temp_mean) if temp_mean is not None else None,
        "humidity_mean": float(humidity_mean) if humidity_mean is not None else None,
        "road_risk_score": float(risk),
    }

    # include some raw metadata if present
    if "latitude" in data:
        features["latitude"] = data.get("latitude")
    if "longitude" in data:
        features["longitude"] = data.get("longitude")
    if "generationtime_ms" in data:
        features["generationtime_ms"] = data.get("generationtime_ms")

    return data, features


def _haversine_km(a_lat: float, a_lon: float, b_lat: float, b_lon: float) -> float:
    # returns distance in kilometers
    R = 6371.0
    lat1, lon1, lat2, lon2 = map(math.radians, (a_lat, a_lon, b_lat, b_lon))
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    h = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    return 2 * R * math.asin(min(1.0, math.sqrt(h)))


def risk_to_index(risk_score: float, max_risk: float = 10.0, num_bins: int = 10) -> int:
    """
    Map a numeric risk_score to an integer index 1..num_bins (higher => more risky).
    """
    if risk_score is None:
        return 1
    r = float(risk_score)
    if r <= 0:
        return 1
    if r >= max_risk:
        return num_bins
    bin_width = max_risk / float(num_bins)
    return int(r // bin_width) + 1


def get_risk_score(lat: float, lon: float, **fetch_kwargs) -> float:
    """Wrapper: calls fetch_road_risk and returns features['road_risk_score'] (float)."""
    _, features = fetch_road_risk(lat, lon, extra_params=fetch_kwargs)
    return float(features.get("road_risk_score", 0.0))


def compute_reroute(
    start_lat: float,
    start_lon: float,
    end_lat: float,
    end_lon: float,
    risk_provider: Callable[[float, float], float] = None,
    n_lat: int = 20,
    n_lon: int = 20,
    distance_weight: float = 0.1,
    max_calls: Optional[int] = None
) -> Dict[str, Any]:
    """
    Plan a path from (start_lat, start_lon) to (end_lat, end_lon) that avoids risky areas.
    Uses Dijkstra's algorithm over a lat/lon grid with cost = avg risk + distance_weight * distance.
    """
    if risk_provider is None:
        risk_provider = lambda lat, lon: get_risk_score(lat, lon)

    min_lat = min(start_lat, end_lat)
    max_lat = max(start_lat, end_lat)
    min_lon = min(start_lon, end_lon)
    max_lon = max(start_lon, end_lon)

    lat_padding = (max_lat - min_lat) * 0.2
    lon_padding = (max_lon - min_lon) * 0.2
    min_lat -= lat_padding
    max_lat += lat_padding
    min_lon -= lon_padding
    max_lon += lon_padding

    lat_step = (max_lat - min_lat) / (n_lat - 1) if n_lat > 1 else 0.0
    lon_step = (max_lon - min_lon) / (n_lon - 1) if n_lon > 1 else 0.0

    coords = []
    for i in range(n_lat):
        for j in range(n_lon):
            coords.append((min_lat + i * lat_step, min_lon + j * lon_step))

    risks = []
    calls = 0
    for lat, lon in coords:
        if max_calls is not None and calls >= max_calls:
            risks.append(float('inf'))
            continue
        try:
            risk = risk_provider(lat, lon)
        except Exception:
            risk = float('inf')
        risks.append(float(risk))
        calls += 1

    def idx(i, j):
        return i * n_lon + j

    def find_closest(lat, lon):
        i = round((lat - min_lat) / (lat_step if lat_step != 0 else 1e-9))
        j = round((lon - min_lon) / (lon_step if lon_step != 0 else 1e-9))
        i = max(0, min(n_lat - 1, i))
        j = max(0, min(n_lon - 1, j))
        return idx(i, j)

    start_idx = find_closest(start_lat, start_lon)
    end_idx = find_closest(end_lat, end_lon)

    N = len(coords)
    dist = [math.inf] * N
    prev = [None] * N
    dist[start_idx] = 0.0
    pq = [(0.0, start_idx)]

    while pq:
        cost, u = heapq.heappop(pq)
        if cost > dist[u]:
            continue
        if u == end_idx:
            break

        ui, uj = u // n_lon, u % n_lon
        for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            vi, vj = ui + di, uj + dj
            if 0 <= vi < n_lat and 0 <= vj < n_lon:
                v = idx(vi, vj)
                if math.isinf(risks[v]) or math.isinf(risks[u]):
                    continue
                lat_u, lon_u = coords[u]
                lat_v, lon_v = coords[v]
                d_km = _haversine_km(lat_u, lon_u, lat_v, lon_v)
                edge_cost = (risks[u] + risks[v]) / 2 + distance_weight * d_km
                new_cost = cost + edge_cost
                if new_cost < dist[v]:
                    dist[v] = new_cost
                    prev[v] = u
                    heapq.heappush(pq, (new_cost, v))

    if math.isinf(dist[end_idx]):
        return {
            "reroute_needed": False,
            "reason": "no_path_found",
            "start_coord": (start_lat, start_lon),
            "end_coord": (end_lat, end_lon),
            "calls_made": calls
        }

    path_indices = []
    u = end_idx
    while u is not None:
        path_indices.append(u)
        u = prev[u]
    path_indices.reverse()

    path_coords = [coords[i] for i in path_indices]
    return {
        "reroute_needed": True,
        "start_coord": (start_lat, start_lon),
        "end_coord": (end_lat, end_lon),
        "path": path_coords,
        "total_cost": dist[end_idx],
        "start_risk": risks[start_idx],
        "end_risk": risks[end_idx],
        "calls_made": calls,
        "grid_shape": (n_lat, n_lon)
    }


def compute_index_and_reroute(lat: float,
                              lon: float,
                              api_key: Optional[str] = None,
                              roadrisk_url: Optional[str] = None,
                              max_risk: float = 10.0,
                              num_bins: int = 10,
                              reroute_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Get road risk, map to index (1..num_bins), and attempt reroute.
    reroute_kwargs forwarded to compute_reroute.
    api_key/roadrisk_url accepted for compatibility but ignored by Open-Meteo implementation.
    """
    if reroute_kwargs is None:
        reroute_kwargs = {}

    data, features = fetch_road_risk(lat, lon, extra_params=reroute_kwargs, api_key=api_key, roadrisk_url=roadrisk_url)
    road_risk = float(features.get("road_risk_score", 0.0))

    accidents = features.get("accidents") or features.get("accident_count")
    try:
        if accidents is not None:
            # fallback: map accident count to index if present
            from .models import accidents_to_bucket  # may not exist; wrapped in try
            idx = accidents_to_bucket(int(accidents), max_count=20000, num_bins=num_bins)
        else:
            idx = risk_to_index(road_risk, max_risk=max_risk, num_bins=num_bins)
    except Exception:
        idx = risk_to_index(road_risk, max_risk=max_risk, num_bins=num_bins)

    def _rp(lat_, lon_):
        return get_risk_score(lat_, lon_, api_key=api_key, roadrisk_url=roadrisk_url)

    reroute_info = compute_reroute(lat, lon, risk_provider=_rp, **reroute_kwargs)
    return {
        "lat": lat,
        "lon": lon,
        "index": int(idx),
        "road_risk_score": road_risk,
        "features": features,
        "reroute": reroute_info,
        "raw_roadrisk_response": data,
    }
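The routing helpers above are pure functions over a `risk_provider` callable, so they can be sanity-checked offline. A minimal sketch, assuming it is run from the roadcast/ directory; the risk field below is synthetic, so the resulting path is only illustrative:

# offline check for the openmeteo_client helpers; no network access needed
from openmeteo_client import _haversine_km, risk_to_index, compute_reroute

# roughly 5.8 km between the two example points used elsewhere in this commit
print(round(_haversine_km(38.9, -77.0, 38.95, -77.02), 1))

# risk_to_index buckets a score into 1..num_bins (defaults: 10 bins over 0..10)
assert risk_to_index(0.0) == 1
assert risk_to_index(4.9) == 5
assert risk_to_index(99.0) == 10

def fake_risk(lat, lon):
    # synthetic hazard: a risky north-south band near lon = -77.01
    return 5.0 if abs(lon + 77.01) < 0.003 else 0.5

route = compute_reroute(38.9, -77.0, 38.95, -77.02,
                        risk_provider=fake_risk, n_lat=10, n_lon=10)
print(route["reroute_needed"], len(route["path"]), route["total_cost"])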
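For the live path, `fetch_weather` and `get_risk_score` call the Open-Meteo archive endpoint directly (network required, no API key); the dates below reuse the commit's example timestamp and are only an assumption:

# live check (requires network); Open-Meteo's archive API needs no key
from openmeteo_client import fetch_weather, get_risk_score

data = fetch_weather(38.9, -77.0,
                     params={"start_date": "2025-09-26", "end_date": "2025-09-27"})
print(sorted(data.get("hourly", {}).keys()))

# heuristic score derived from the same hourly series; always >= 0.0
print(get_risk_score(38.9, -77.0))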
roadcast/openmeteo_inference.py (new file, 339 lines)
@@ -0,0 +1,339 @@
"""
Fetch OpenWeather data for a coordinate/time and run the trained MLP to predict the k-means cluster label.

Usage examples:
# with training CSV provided to compute preprocessing stats:
python openmeteo_inference.py --lat 38.9 --lon -77.0 --datetime "2025-09-27T12:00:00" --train-csv data.csv --model model.pth --centers kmeans_centers_all.npz --api-key $OPENWEATHER_KEY

# with precomputed preprocess meta (saved from training):
python openmeteo_inference.py --lat 38.9 --lon -77.0 --datetime "2025-09-27T12:00:00" --preprocess-meta preprocess_meta.npz --model model.pth --centers kmeans_centers_all.npz --api-key $OPENWEATHER_KEY

Notes:
- The script uses the same feature-engineering helpers in `data.py` so the model sees identical inputs.
- You must either provide `--train-csv` (to compute feature columns & means/stds) or `--preprocess-meta` previously saved.
- Provide the OpenWeather API key via --api-key or the OPENWEATHER_KEY environment variable.
"""

import os
import argparse
import json
from datetime import datetime
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# reuse helpers from your repo
from data import _add_date_features, _add_latlon_bins, _add_hashed_street, CSVDataset
from inference import load_model

# module-level caches to avoid reloading heavy artifacts per request
_CACHED_MODEL = None
_CACHED_IDX_TO_CLASS = None
_CACHED_CENTERS = None
_CACHED_PREPROCESS_META = None


OW_BASE = 'https://api.openweathermap.org/data/2.5/onecall'


def fetch_openmeteo(lat, lon, api_key, dt_iso=None):
    """Fetch weather from OpenWeather One Call API for given lat/lon. If dt_iso provided, we fetch current+hourly and pick the closest timestamp."""
    try:
        import requests
    except Exception:
        raise RuntimeError('requests library is required to fetch OpenWeather data')
    params = {
        'lat': float(lat),
        'lon': float(lon),
        'appid': api_key,
        'units': 'metric',
        'exclude': 'minutely,alerts'
    }
    r = requests.get(OW_BASE, params=params, timeout=10)
    r.raise_for_status()
    payload = r.json()
    # if dt_iso provided, find nearest hourly data point
    if dt_iso:
        try:
            target = pd.to_datetime(dt_iso)
        except Exception:
            target = None
        best = None
        if 'hourly' in payload and target is not None:
            hours = payload['hourly']
            best = min(hours, key=lambda h: abs(pd.to_datetime(h['dt'], unit='s') - target))
        if best is not None:
            # convert keys to a flat dict with prefix 'ow_'
            d = {
                'ow_temp': best.get('temp'),
                'ow_feels_like': best.get('feels_like'),
                'ow_pressure': best.get('pressure'),
                'ow_humidity': best.get('humidity'),
                'ow_wind_speed': best.get('wind_speed'),
                'ow_clouds': best.get('clouds'),
                'ow_pop': best.get('pop'),
            }
            return d
    # fallback: use current
    cur = payload.get('current', {})
    d = {
        'ow_temp': cur.get('temp'),
        'ow_feels_like': cur.get('feels_like'),
        'ow_pressure': cur.get('pressure'),
        'ow_humidity': cur.get('humidity'),
        'ow_wind_speed': cur.get('wind_speed'),
        'ow_clouds': cur.get('clouds'),
        'ow_pop': None,
    }
    return d


def fetch_roadrisk(roadrisk_url, api_key=None):
    """Fetch the RoadRisk endpoint (expects JSON). If `api_key` is provided, we'll attach it as a query param if the URL has no key.

    We flatten top-level numeric fields into `rr_*` keys for the feature row.
    """
    # if api_key provided and url does not contain appid, append it
    try:
        import requests
    except Exception:
        raise RuntimeError('requests library is required to fetch RoadRisk data')
    url = roadrisk_url
    if api_key and 'appid=' not in roadrisk_url:
        sep = '&' if '?' in roadrisk_url else '?'
        url = f"{roadrisk_url}{sep}appid={api_key}"

    r = requests.get(url, timeout=10)
    r.raise_for_status()
    payload = r.json()
    # flatten numeric top-level fields
    out = {}
    if isinstance(payload, dict):
        for k, v in payload.items():
            if isinstance(v, (int, float)):
                out[f'rr_{k}'] = v
            # if nested objects contain simple numeric fields, pull them too (one level deep)
            elif isinstance(v, dict):
                for kk, vv in v.items():
                    if isinstance(vv, (int, float)):
                        out[f'rr_{k}_{kk}'] = vv
    return out


def build_row(lat, lon, dt_iso=None, street=None, extra_weather=None):
    """Construct a single-row DataFrame with columns expected by the training pipeline.

    It intentionally uses column names the original `data.py` looked for (REPORTDATE, LATITUDE, LONGITUDE, ADDRESS, etc.).
    """
    row = {}
    # date column matching common names
    row['REPORTDATE'] = dt_iso if dt_iso else datetime.utcnow().isoformat()
    row['LATITUDE'] = lat
    row['LONGITUDE'] = lon
    row['ADDRESS'] = street if street else ''
    # include some injury/fatality placeholders that the label generator expects
    row['INJURIES'] = 0
    row['FATALITIES'] = 0
    # include weather features returned by OpenWeather (prefixed 'ow_')
    if extra_weather:
        for k, v in extra_weather.items():
            row[k] = v
    return pd.DataFrame([row])


def prepare_features(df_row, train_csv=None, preprocess_meta=None, feature_engineer=True, lat_lon_bins=20):
    """Given a one-row DataFrame, apply the same feature engineering and standardization as training.

    If preprocess_meta is provided (npz), use it. Otherwise train_csv must be provided to compute stats.
    Returns a torch.FloatTensor of shape (1, input_dim) and the feature_columns list.
    """
    # apply feature engineering helpers
    if feature_engineer:
        try:
            _add_date_features(df_row)
        except Exception:
            pass
        try:
            _add_latlon_bins(df_row, bins=lat_lon_bins)
        except Exception:
            pass
        try:
            _add_hashed_street(df_row)
        except Exception:
            pass

    # if meta provided, load feature_columns, means, stds
    if preprocess_meta and os.path.exists(preprocess_meta):
        meta = np.load(preprocess_meta, allow_pickle=True)
        feature_columns = meta['feature_columns'].tolist()
        means = meta['means']
        stds = meta['stds']
    else:
        if not train_csv:
            raise ValueError('Either preprocess_meta or train_csv must be provided to derive feature stats')
        # instantiate a CSVDataset on train_csv (feature_engineer True) to reuse its preprocessing
        ds = CSVDataset(train_csv, feature_columns=None, label_column='label', generate_labels=True, n_buckets=10, label_method='kmeans', label_store=None, feature_engineer=feature_engineer, lat_lon_bins=lat_lon_bins, nrows=None)
        feature_columns = ds.feature_columns
        means = ds.feature_means
        stds = ds.feature_stds
        # save meta for reuse
        np.savez_compressed('preprocess_meta.npz', feature_columns=np.array(feature_columns, dtype=object), means=means, stds=stds)
        print('Saved preprocess_meta.npz')

    # ensure all feature columns exist in df_row
    for c in feature_columns:
        if c not in df_row.columns:
            df_row[c] = 0

    # coerce and fill using means
    features_df = df_row[feature_columns].apply(lambda c: pd.to_numeric(c, errors='coerce'))
    features_df = features_df.fillna(pd.Series(means, index=feature_columns)).fillna(0.0)
    # standardize
    features_np = (features_df.values - means) / (stds + 1e-6)
    import torch
    return torch.tensor(features_np, dtype=torch.float32), feature_columns


def predict_from_openmeteo(lat, lon, dt_iso=None, street=None, api_key=None, train_csv=None, preprocess_meta=None, model_path='model.pth', centers_path='kmeans_centers_all.npz', roadrisk_url=None):
    api_key = api_key or os.environ.get('OPENWEATHER_KEY')
    if api_key is None:
        raise ValueError('OpenWeather API key required via --api-key or OPENWEATHER_KEY env var')

    # gather weather/road-risk features
    weather = {}
    if roadrisk_url:
        try:
            rr = fetch_roadrisk(roadrisk_url, api_key=api_key)
            weather.update(rr)
        except Exception as e:
            print('Warning: failed to fetch roadrisk URL:', e)
    else:
        try:
            ow = fetch_openmeteo(lat, lon, api_key, dt_iso=dt_iso)
            weather.update(ow)
        except Exception as e:
            print('Warning: failed to fetch openweather:', e)

    df_row = build_row(lat, lon, dt_iso=dt_iso, street=street, extra_weather=weather)
    x_tensor, feature_columns = prepare_features(df_row, train_csv=train_csv, preprocess_meta=preprocess_meta)

    # load model (infer num_classes from centers file if possible)
    global _CACHED_MODEL, _CACHED_IDX_TO_CLASS, _CACHED_CENTERS, _CACHED_PREPROCESS_META

    # ensure we have preprocess_meta available (prefer supplied path, otherwise fallback to saved file)
    if preprocess_meta is None:
        candidate = os.path.join(os.getcwd(), 'preprocess_meta.npz')
        if os.path.exists(candidate):
            preprocess_meta = candidate

    # load centers (cache across requests)
    if _CACHED_CENTERS is None:
        if centers_path and os.path.exists(centers_path):
            try:
                npz = np.load(centers_path)
                _CACHED_CENTERS = npz['centers']
            except Exception:
                _CACHED_CENTERS = None
        else:
            _CACHED_CENTERS = None

    num_classes = _CACHED_CENTERS.shape[0] if _CACHED_CENTERS is not None else 10

    # load model once and cache it
    if _CACHED_MODEL is None:
        try:
            _CACHED_MODEL, _CACHED_IDX_TO_CLASS = load_model(model_path, device=None, in_channels=3, num_classes=num_classes)
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            _CACHED_MODEL.to(device)
        except Exception as e:
            raise
    model = _CACHED_MODEL
    idx_to_class = _CACHED_IDX_TO_CLASS
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x_tensor = x_tensor.to(device)
    with torch.no_grad():
        logits = model(x_tensor)
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        pred_idx = int(probs.argmax())
        confidence = float(probs.max())

    # optionally provide cluster centroid info
    centroid = _CACHED_CENTERS[pred_idx] if _CACHED_CENTERS is not None else None

    return {
        'pred_cluster': int(pred_idx),
        'confidence': confidence,
        'probabilities': probs.tolist(),
        'centroid': centroid.tolist() if centroid is not None else None,
        'feature_columns': feature_columns,
        'used_preprocess_meta': preprocess_meta
    }


def init_inference(model_path='model.pth', centers_path='kmeans_centers_all.npz', preprocess_meta=None):
    """Eagerly load model, centers, and preprocess_meta into module-level caches.

    This is intended to be called at app startup to surface load errors early and avoid
    per-request disk IO. The function is best-effort and will print warnings if artifacts
    are missing.
    """
    global _CACHED_MODEL, _CACHED_IDX_TO_CLASS, _CACHED_CENTERS, _CACHED_PREPROCESS_META

    # prefer existing saved preprocess_meta if not explicitly provided
    if preprocess_meta is None:
        candidate = os.path.join(os.getcwd(), 'preprocess_meta.npz')
        if os.path.exists(candidate):
            preprocess_meta = candidate

    _CACHED_PREPROCESS_META = preprocess_meta

    # load centers
    if _CACHED_CENTERS is None:
        if centers_path and os.path.exists(centers_path):
            try:
                npz = np.load(centers_path)
                _CACHED_CENTERS = npz['centers']
                print(f'Loaded centers from {centers_path}')
            except Exception as e:
                print('Warning: failed to load centers:', e)
                _CACHED_CENTERS = None
        else:
            print('No centers file found at', centers_path)
            _CACHED_CENTERS = None

    num_classes = _CACHED_CENTERS.shape[0] if _CACHED_CENTERS is not None else 10

    # load model
    if _CACHED_MODEL is None:
        try:
            _CACHED_MODEL, _CACHED_IDX_TO_CLASS = load_model(model_path, device=None, in_channels=3, num_classes=num_classes)
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            _CACHED_MODEL.to(device)
            print(f'Loaded model from {model_path}')
        except Exception as e:
            print('Warning: failed to load model:', e)
            _CACHED_MODEL = None

    return {
        'model_loaded': _CACHED_MODEL is not None,
        'centers_loaded': _CACHED_CENTERS is not None,
        'preprocess_meta': _CACHED_PREPROCESS_META
    }


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lat', type=float, required=True)
    parser.add_argument('--lon', type=float, required=True)
    parser.add_argument('--datetime', default=None, help='ISO datetime string to query hourly weather (optional)')
    parser.add_argument('--street', default='')
    parser.add_argument('--api-key', default=None, help='OpenWeather API key or use OPENWEATHER_KEY env var')
    parser.add_argument('--train-csv', default=None, help='Path to training CSV to compute preprocessing stats (optional if --preprocess-meta provided)')
    parser.add_argument('--preprocess-meta', default=None, help='Path to precomputed preprocess_meta.npz (optional)')
    parser.add_argument('--model', default='model.pth')
    parser.add_argument('--centers', default='kmeans_centers_all.npz')
    parser.add_argument('--roadrisk-url', default=None, help='Optional custom RoadRisk API URL (if provided, will be queried instead of OneCall)')
    args = parser.parse_args()

    out = predict_from_openmeteo(args.lat, args.lon, dt_iso=args.datetime, street=args.street, api_key=args.api_key, train_csv=args.train_csv, preprocess_meta=args.preprocess_meta, model_path=args.model, centers_path=args.centers, roadrisk_url=args.roadrisk_url)
    print(json.dumps(out, indent=2))
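One detail of `prepare_features` worth pinning down: rows are z-scored against the training statistics, with NaNs first filled from the means and a small epsilon guarding zero-variance columns. A minimal numeric illustration (the values are invented):

# standalone illustration of the standardization used in prepare_features:
# z = (x - means) / (stds + 1e-6)
import numpy as np

means = np.array([10.0, 0.5])
stds = np.array([2.0, 0.0])   # zero-variance column stays finite thanks to the epsilon
x = np.array([[12.0, 0.5]])

z = (x - means) / (stds + 1e-6)
print(z)  # approximately [[1.0, 0.0]]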