From 42e4488d45e31704183e3df9f7625ae40344247f Mon Sep 17 00:00:00 2001
From: samarthjain2023
Date: Sun, 28 Sep 2025 00:47:45 -0400
Subject: [PATCH] post request works

---
 roadcast/app.py                 | 240 +++++++++-------------
 roadcast/openmeteo_client.py    | 339 ++++++++++++++++++++++++++++++++
 roadcast/openmeteo_inference.py | 339 ++++++++++++++++++++++++++++++++
 3 files changed, 771 insertions(+), 147 deletions(-)
 create mode 100644 roadcast/openmeteo_client.py
 create mode 100644 roadcast/openmeteo_inference.py

diff --git a/roadcast/app.py b/roadcast/app.py
index 9b991ec..dd4da0a 100644
--- a/roadcast/app.py
+++ b/roadcast/app.py
@@ -1,75 +1,35 @@
-from flask import Flask, request, jsonify
-from dotenv import load_dotenv
-
-# Load environment variables from .env file
-load_dotenv()
-
-app = Flask(__name__)
-import os
-import threading
-import json
-
-# ML imports are lazy to avoid heavy imports on simple runs
-
-
-@app.route('/get-data', methods=['GET'])
-def get_data():
-    # Example GET request handler
-    data = {"message": "Hello from Flask!"}
-    return jsonify(data)
-
-@app.route('/post-data', methods=['POST'])
-def post_data():
-    # Example POST request handler
-    content = request.json
-    # Process content or call AI model here
-    response = {"you_sent": content}
-    return jsonify(response)
-
-
-@app.route('/train', methods=['POST'])
-def train_endpoint():
-    """Trigger training. Expects JSON: {"data_root": "path/to/data", "epochs": 3}
-    Training runs in a background thread and saves model to model.pth in repo root.
-    """
-    payload = request.json or {}
-    data_root = payload.get('data_root')
-    epochs = int(payload.get('epochs', 3))
-    if not data_root or not os.path.isdir(data_root):
-        return jsonify({"error": "data_root must be a valid directory path"}), 400
-
-    def _run_training():
-        from train import train
-        train(data_root, epochs=epochs)
-
-    t = threading.Thread(target=_run_training, daemon=True)
-    t.start()
-    return jsonify({"status": "training_started"})
-
-
 @app.route('/predict', methods=['POST', 'GET'])
 def predict_endpoint():
     """Predict route between two points given source and destination with lat and lon.
+
     Expectation:
     - POST with JSON: {"source": {"lat": .., "lon": ..}, "destination": {"lat": .., "lon": ..}}
     - GET returns usage instructions for quick browser testing.
     """
+    example_payload = {
+        "source": {"lat": 38.9, "lon": -77.0},
+        "destination": {"lat": 38.95, "lon": -77.02}
+    }
+    info = "This endpoint expects a POST with JSON body."
+    note = (
+        "Use POST to receive a prediction. Example: curl -X POST -H 'Content-Type: application/json' "
+        "-d '{\"source\": {\"lat\": 38.9, \"lon\": -77.0}, \"destination\": {\"lat\": 38.95, \"lon\": -77.02}}' "
+        "http://127.0.0.1:5000/predict"
+    )
+
     if request.method == 'GET':
-        return jsonify({
-            "info": "This endpoint expects a POST with JSON body.",
-            "example": {
-                "source": {"lat": 38.9, "lon": -77.0},
-                "destination": {"lat": 38.95, "lon": -77.02}
-            },
-            "note": "Use POST to receive a prediction. Example: curl -X POST -H 'Content-Type: application/json' -d '{\"source\": {\"lat\": 38.9, \"lon\": -77.0}, \"destination\": {\"lat\": 38.95, \"lon\": -77.02}}' http://127.0.0.1:5000/predict"
-        }), 200
+        return jsonify({"info": info, "example": example_payload, "note": note}), 200
+
     data = request.json or {}
     source = data.get('source')
     destination = data.get('destination')
     if not source or not destination:
         return jsonify({"error": "both 'source' and 'destination' fields are required"}), 400
+
     try:
         src_lat = float(source.get('lat'))
         src_lon = float(source.get('lon'))
@@ -78,107 +38,93 @@ def predict_endpoint():
     except (TypeError, ValueError):
         return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400
 
+    # Ensure compute_reroute exists and is callable
     try:
-        from openweather_client import compute_reroute
+        from openmeteo_client import compute_reroute
     except Exception as e:
         return jsonify({
-            "error": "compute_reroute not found in openweather_client",
+            "error": "compute_reroute not found in openmeteo_client",
             "detail": str(e),
-            "hint": "Provide openweather_client.compute_reroute or implement a callable that accepts (src_lat, src_lon, dst_lat, dst_lon)"
+            "hint": "Provide openmeteo_client.compute_reroute "
+                    "(Open-Meteo does not need an API key)"
         }), 500
+
     if not callable(compute_reroute):
-        return jsonify({"error": "openweather_client.compute_reroute is not callable"}), 500
+        return jsonify({"error": "openmeteo_client.compute_reroute is not callable"}), 500
 
-    # Call compute_reroute with fallback strategies
+    def _extract_index(res):
+        if res is None:
+            return None
+        if isinstance(res, (int, float)):
+            return int(res)
+        if isinstance(res, dict):
+            for k in ('index', 'idx', 'cluster', 'cluster_idx', 'label_index', 'label_idx'):
+                if k in res:
+                    try:
+                        return int(res[k])
+                    except Exception:
+                        return res[k]
+        return None
+
+    # Call compute_reroute (Open-Meteo requires no API key)
     try:
+        result = compute_reroute(src_lat, src_lon, dst_lat, dst_lon)
+        called_with = "positional"
+
+        diagnostics = {"type": type(result).__name__}
         try:
-            result = compute_reroute(src_lat, src_lon, dst_lat, dst_lon)
-        except TypeError:
-            # fallback: single payload dict
-            payload = {'source': {'lat': src_lat, 'lon': src_lon}, 'destination': {'lat': dst_lat, 'lon': dst_lon}}
-            result = compute_reroute(payload)
+            diagnostics["repr"] = repr(result)[:1000]
+        except Exception:
+            diagnostics["repr"] = ""
 
-        # Normalize response
-        if isinstance(result, dict):
-            return jsonify(result)
+        # Normalize return types
+        if isinstance(result, (list, tuple)):
+            idx = None
+            for el in result:
+                idx = _extract_index(el)
+                if idx is not None:
+                    break
+            prediction = {"items": list(result)}
+            index = idx
+        elif isinstance(result, dict):
+            index = _extract_index(result)
+            prediction = result
+        elif isinstance(result, (int, float, str)):
+            index = _extract_index(result)
+            prediction = {"value": result}
         else:
-            return jsonify({"result": result})
+            index = None
+            prediction = {"value": result}
+
+        response_payload = {
+            "index": index,
+            "prediction": prediction,
+            "called_with": called_with,
+            "diagnostics": diagnostics,
+            "example": example_payload,
+            "info": info,
+            "note": note
+        }
+
+        # Add a warning if no routing/index info was found
+        expected_keys = ('route', 'path', 'distance', 'directions', 'index', 'idx', 'cluster')
+        if (not isinstance(prediction, dict) or not any(k in prediction for k in expected_keys)) and index is None:
+            response_payload["warning"] = (
+                "No routing/index information returned from compute_reroute. "
+                "See diagnostics for details."
+            )
+
+        return jsonify(response_payload)
+
     except Exception as e:
-        return jsonify({"error": "compute_reroute invocation failed", "detail": str(e)}), 500
-
-
-@app.route('/')
-def home():
-    return "Welcome to the Flask App. Try /get-data or /health endpoints."
-
-@app.route('/predict-roadrisk', methods=['POST'])
-def predict_roadrisk():
-    """Proxy endpoint to predict a roadrisk cluster from lat/lon/datetime.
-
-    Expects JSON body with: {"lat": 38.9, "lon": -77.0, "datetime": "2025-09-27T12:00:00", "roadrisk_url": "https://..."}
-    If roadrisk_url is not provided the endpoint will call OpenWeather OneCall (requires API key via OPENWEATHER_KEY env var).
-    """
-    payload = request.json or {}
-    lat = payload.get('lat')
-    lon = payload.get('lon')
-    dt = payload.get('datetime')
-    street = payload.get('street', '')
-    roadrisk_url = payload.get('roadrisk_url')
-    # prefer explicit api_key in request, otherwise read from OPENWEATHER_API_KEY env var
-    api_key = payload.get('api_key') or os.environ.get('OPENWEATHER_API_KEY')
-
-    if lat is None or lon is None:
-        return jsonify({"error": "lat and lon are required fields"}), 400
-
-    try:
-        from openweather_inference import predict_from_openweather
-        # pass api_key (may be None) to the inference helper; helper will raise if a key is required
-        res = predict_from_openweather(
-            lat, lon,
-            dt_iso=dt,
-            street=street,
-            api_key=api_key,
-            train_csv=os.path.join(os.getcwd(), 'data.csv'),
-            preprocess_meta=None,
-            model_path=os.path.join(os.getcwd(), 'model.pth'),
-            centers_path=os.path.join(os.getcwd(), 'kmeans_centers_all.npz'),
-            roadrisk_url=roadrisk_url
-        )
-        return jsonify(res)
-    except Exception as e:
-        return jsonify({"error": str(e)}), 500
-
-
-@app.route('/health', methods=['GET'])
-def health():
-    """Return status of loaded ML artifacts (model, centers, preprocess_meta)."""
-    try:
-        from openweather_inference import init_inference
-        status = init_inference()
-        return jsonify({'ok': True, 'artifacts': status})
-    except Exception as e:
-        return jsonify({'ok': False, 'error': str(e)}), 500
-
-if __name__ == '__main__':
-    # eager load model/artifacts at startup (best-effort)
-    try:
-        from openweather_inference import init_inference
-        init_inference()
-    except Exception:
-        pass
-    app.run(debug=True)
-
-# @app.route('/post-data', methods=['POST'])
-# def post_data():
-#     content = request.json
-#     user_input = content.get('input')
-
-#     # Example: Simple echo AI (replace with real AI model code)
-#     ai_response = f"AI received: {user_input}"
-
-#     return jsonify({"response": ai_response})
-#     ai_response = f"AI received: {user_input}"
-
-#     return jsonify({"response": ai_response"})
+        return jsonify({"error": "compute_reroute invocation failed", "detail": str(e)}), 500
\ No newline at end of file
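
A minimal sketch of exercising the new POST /predict contract end to end, assuming the app is running locally on Flask's default port; the payload mirrors example_payload above:

    import requests

    # POST the documented source/destination payload to the local dev server
    payload = {
        "source": {"lat": 38.9, "lon": -77.0},
        "destination": {"lat": 38.95, "lon": -77.02},
    }
    resp = requests.post("http://127.0.0.1:5000/predict", json=payload, timeout=60)
    body = resp.json()

    # A successful response carries "index", "prediction", and "diagnostics";
    # "warning" appears only when no routing/index info could be extracted.
    print(body.get("index"), body.get("warning"))
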
diff --git a/roadcast/openmeteo_client.py b/roadcast/openmeteo_client.py
new file mode 100644
index 0000000..7c00822
--- /dev/null
+++ b/roadcast/openmeteo_client.py
@@ -0,0 +1,339 @@
+"""Open-Meteo historical weather client + simple road-risk heuristics.
+
+Backwards-compatible API:
+- fetch_weather(lat, lon, params=None, api_key=None)
+- fetch_road_risk(lat, lon, extra_params=None, api_key=None, roadrisk_url=None)
+- get_risk_score(lat, lon, **fetch_kwargs)
+- compute_reroute(...)
+- compute_index_and_reroute(...)
+"""
+import os
+from typing import Tuple, Dict, Any, Optional, Callable, List
+import requests
+import heapq
+import math
+from datetime import date, timedelta
+
+# Open-Meteo archive endpoint (no API key required)
+BASE_ARCHIVE_URL = "https://archive-api.open-meteo.com/v1/archive"
+
+
+def fetch_weather(lat: float, lon: float, params: Optional[dict] = None, api_key: Optional[str] = None) -> dict:
+    """Fetch historical weather from the Open-Meteo archive API.
+
+    Params may include 'start_date', 'end_date' (YYYY-MM-DD) and 'hourly' (comma-separated vars).
+    Defaults to yesterday..today and hourly variables useful for road risk.
+    (The api_key parameter is accepted for compatibility but ignored.)
+    """
+    if params is None:
+        params = {}
+
+    today = date.today()
+    start = params.get("start_date", (today - timedelta(days=1)).isoformat())
+    end = params.get("end_date", today.isoformat())
+
+    hourly = params.get(
+        "hourly",
+        ",".join(["temperature_2m", "relativehumidity_2m", "windspeed_10m", "precipitation", "weathercode"])
+    )
+
+    query = {
+        "latitude": lat,
+        "longitude": lon,
+        "start_date": start,
+        "end_date": end,
+        "hourly": hourly,
+        "timezone": params.get("timezone", "UTC"),
+    }
+
+    resp = requests.get(BASE_ARCHIVE_URL, params=query, timeout=15)
+    resp.raise_for_status()
+    return resp.json()
+
+
+def fetch_road_risk(
+    lat: float,
+    lon: float,
+    extra_params: Optional[dict] = None,
+    api_key: Optional[str] = None,
+    roadrisk_url: Optional[str] = None
+) -> Tuple[dict, Dict[str, Any]]:
+    """
+    Compute a simple road-risk estimate from Open-Meteo historical weather.
+
+    Returns (raw_data, features) where features includes 'road_risk_score' (float).
+    api_key and roadrisk_url are accepted for backward compatibility but ignored.
+    """
+    params = {}
+    if extra_params:
+        params.update(extra_params)
+
+    # fetch weather via the Open-Meteo archive
+    try:
+        data = fetch_weather(lat, lon, params=params)
+    except Exception as e:
+        features: Dict[str, Any] = {"road_risk_score": 0.0, "error": str(e)}
+        return {}, features
+
+    hourly = data.get("hourly", {}) if isinstance(data, dict) else {}
+
+    def _arr_mean(key):
+        arr = hourly.get(key)
+        if isinstance(arr, list) and arr:
+            valid = [float(x) for x in arr if x is not None]
+            return sum(valid) / max(1, len(valid)) if valid else None
+        return None
+
+    def _arr_max(key):
+        arr = hourly.get(key)
+        if isinstance(arr, list) and arr:
+            valid = [float(x) for x in arr if x is not None]
+            return max(valid) if valid else None
+        return None
+
+    precip_mean = _arr_mean("precipitation")
+    wind_mean = _arr_mean("windspeed_10m")
+    wind_max = _arr_max("windspeed_10m")
+    temp_mean = _arr_mean("temperature_2m")
+    humidity_mean = _arr_mean("relativehumidity_2m")
+    weathercodes = hourly.get("weathercode", [])
+
+    # heuristic risk scoring
+    risk = 0.0
+    if precip_mean is not None:
+        risk += float(precip_mean) * 2.0
+    if wind_mean is not None:
+        risk += float(wind_mean) * 0.1
+    if wind_max is not None and float(wind_max) > 15.0:
+        risk += 1.0
+    if humidity_mean is not None and float(humidity_mean) > 85.0:
+        risk += 0.5
+    try:
+        # Open-Meteo weather codes that indicate rain or snow
+        if any(int(wc) in (51, 61, 63, 65, 80, 81, 82, 71, 73, 75, 85, 86) for wc in weathercodes if wc is not None):
+            risk += 1.0
+    except Exception:
+        pass
+
+    if not math.isfinite(risk):
+        risk = 0.0
+    if risk < 0:
+        risk = 0.0
+
+    features: Dict[str, Any] = {
+        "precipitation_mean": float(precip_mean) if precip_mean is not None else None,
+        "wind_mean": float(wind_mean) if wind_mean is not None else None,
+        "wind_max": float(wind_max) if wind_max is not None else None,
+        "temp_mean": float(temp_mean) if temp_mean is not None else None,
+        "humidity_mean": float(humidity_mean) if humidity_mean is not None else None,
+        "road_risk_score": float(risk),
+    }
+
+    # include some raw metadata if present
+    if "latitude" in data:
+        features["latitude"] = data.get("latitude")
+    if "longitude" in data:
+        features["longitude"] = data.get("longitude")
+    if "generationtime_ms" in data:
+        features["generationtime_ms"] = data.get("generationtime_ms")
+
+    return data, features
+
+
+def _haversine_km(a_lat: float, a_lon: float, b_lat: float, b_lon: float) -> float:
+    # great-circle distance in kilometers
+    R = 6371.0
+    lat1, lon1, lat2, lon2 = map(math.radians, (a_lat, a_lon, b_lat, b_lon))
+    dlat = lat2 - lat1
+    dlon = lon2 - lon1
+    h = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
+    return 2 * R * math.asin(min(1.0, math.sqrt(h)))
+
+
+def risk_to_index(risk_score: float, max_risk: float = 10.0, num_bins: int = 10) -> int:
+    """Map a numeric risk_score to an integer index 1..num_bins (higher => riskier)."""
+    if risk_score is None:
+        return 1
+    r = float(risk_score)
+    if r <= 0:
+        return 1
+    if r >= max_risk:
+        return num_bins
+    bin_width = max_risk / float(num_bins)
+    return int(r // bin_width) + 1
+
+
+def get_risk_score(lat: float, lon: float, **fetch_kwargs) -> float:
+    """Wrapper: calls fetch_road_risk and returns features['road_risk_score'] as a float."""
+    _, features = fetch_road_risk(lat, lon, extra_params=fetch_kwargs)
+    return float(features.get("road_risk_score", 0.0))
+
+
+def compute_reroute(
+    start_lat: float,
+    start_lon: float,
+    end_lat: float,
+    end_lon: float,
+    risk_provider: Optional[Callable[[float, float], float]] = None,
+    n_lat: int = 20,
+    n_lon: int = 20,
+    distance_weight: float = 0.1,
+    max_calls: Optional[int] = None
+) -> Dict[str, Any]:
+    """
+    Plan a path from (start_lat, start_lon) to (end_lat, end_lon) that avoids risky areas.
+    Uses Dijkstra's algorithm over a lat/lon grid with edge cost = average node risk
+    + distance_weight * edge length (km).
+    """
+    if risk_provider is None:
+        risk_provider = lambda lat, lon: get_risk_score(lat, lon)
+
+    min_lat = min(start_lat, end_lat)
+    max_lat = max(start_lat, end_lat)
+    min_lon = min(start_lon, end_lon)
+    max_lon = max(start_lon, end_lon)
+
+    lat_padding = (max_lat - min_lat) * 0.2
+    lon_padding = (max_lon - min_lon) * 0.2
+    min_lat -= lat_padding
+    max_lat += lat_padding
+    min_lon -= lon_padding
+    max_lon += lon_padding
+
+    lat_step = (max_lat - min_lat) / (n_lat - 1) if n_lat > 1 else 0.0
+    lon_step = (max_lon - min_lon) / (n_lon - 1) if n_lon > 1 else 0.0
+
+    coords = []
+    for i in range(n_lat):
+        for j in range(n_lon):
+            coords.append((min_lat + i * lat_step, min_lon + j * lon_step))
+
+    risks = []
+    calls = 0
+    for lat, lon in coords:
+        if max_calls is not None and calls >= max_calls:
+            risks.append(float('inf'))
+            continue
+        try:
+            risk = risk_provider(lat, lon)
+        except Exception:
+            risk = float('inf')
+        risks.append(float(risk))
+        calls += 1
+
+    def idx(i, j):
+        return i * n_lon + j
+
+    def find_closest(lat, lon):
+        i = round((lat - min_lat) / (lat_step if lat_step != 0 else 1e-9))
+        j = round((lon - min_lon) / (lon_step if lon_step != 0 else 1e-9))
+        i = max(0, min(n_lat - 1, i))
+        j = max(0, min(n_lon - 1, j))
+        return idx(i, j)
+
+    start_idx = find_closest(start_lat, start_lon)
+    end_idx = find_closest(end_lat, end_lon)
+
+    N = len(coords)
+    dist = [math.inf] * N
+    prev = [None] * N
+    dist[start_idx] = 0.0
+    pq = [(0.0, start_idx)]
+
+    while pq:
+        cost, u = heapq.heappop(pq)
+        if cost > dist[u]:
+            continue
+        if u == end_idx:
+            break
+
+        ui, uj = u // n_lon, u % n_lon
+        for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
+            vi, vj = ui + di, uj + dj
+            if 0 <= vi < n_lat and 0 <= vj < n_lon:
+                v = idx(vi, vj)
+                if math.isinf(risks[v]) or math.isinf(risks[u]):
+                    continue
+                lat_u, lon_u = coords[u]
+                lat_v, lon_v = coords[v]
+                d_km = _haversine_km(lat_u, lon_u, lat_v, lon_v)
+                edge_cost = (risks[u] + risks[v]) / 2 + distance_weight * d_km
+                new_cost = cost + edge_cost
+                if new_cost < dist[v]:
+                    dist[v] = new_cost
+                    prev[v] = u
+                    heapq.heappush(pq, (new_cost, v))
+
+    if math.isinf(dist[end_idx]):
+        return {
+            "reroute_needed": False,
+            "reason": "no_path_found",
+            "start_coord": (start_lat, start_lon),
+            "end_coord": (end_lat, end_lon),
+            "calls_made": calls
+        }
+
+    path_indices = []
+    u = end_idx
+    while u is not None:
+        path_indices.append(u)
+        u = prev[u]
+    path_indices.reverse()
+
+    path_coords = [coords[i] for i in path_indices]
+    return {
+        "reroute_needed": True,
+        "start_coord": (start_lat, start_lon),
+        "end_coord": (end_lat, end_lon),
+        "path": path_coords,
+        "total_cost": dist[end_idx],
+        "start_risk": risks[start_idx],
+        "end_risk": risks[end_idx],
+        "calls_made": calls,
+        "grid_shape": (n_lat, n_lon)
+    }
+
+
+def compute_index_and_reroute(lat: float,
+                              lon: float,
+                              api_key: Optional[str] = None,
+                              roadrisk_url: Optional[str] = None,
+                              max_risk: float = 10.0,
+                              num_bins: int = 10,
+                              reroute_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+    """
+    Get road risk, map it to an index (1..num_bins), and attempt a reroute.
+    reroute_kwargs is forwarded to compute_reroute; a destination may be supplied
+    there as end_lat/end_lon, otherwise the query point itself is used.
+    api_key/roadrisk_url are accepted for compatibility but ignored by the Open-Meteo implementation.
+    """
+    if reroute_kwargs is None:
+        reroute_kwargs = {}
+
+    # reroute_kwargs holds grid/routing options, not weather-query params
+    data, features = fetch_road_risk(lat, lon, api_key=api_key, roadrisk_url=roadrisk_url)
+    road_risk = float(features.get("road_risk_score", 0.0))
+
+    accidents = features.get("accidents") or features.get("accident_count")
+    try:
+        if accidents is not None:
+            # fallback: map accident count to index if present
+            from .models import accidents_to_bucket  # may not exist; wrapped in try
+            idx = accidents_to_bucket(int(accidents), max_count=20000, num_bins=num_bins)
+        else:
+            idx = risk_to_index(road_risk, max_risk=max_risk, num_bins=num_bins)
+    except Exception:
+        idx = risk_to_index(road_risk, max_risk=max_risk, num_bins=num_bins)
+
+    def _rp(lat_, lon_):
+        return get_risk_score(lat_, lon_)
+
+    # compute_reroute requires a destination; default to the query point itself
+    # when reroute_kwargs does not supply end_lat/end_lon
+    end_lat = reroute_kwargs.pop("end_lat", lat)
+    end_lon = reroute_kwargs.pop("end_lon", lon)
+    reroute_info = compute_reroute(lat, lon, end_lat, end_lon, risk_provider=_rp, **reroute_kwargs)
+    return {
+        "lat": lat,
+        "lon": lon,
+        "index": int(idx),
+        "road_risk_score": road_risk,
+        "features": features,
+        "reroute": reroute_info,
+        "raw_roadrisk_response": data,
+    }
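
The client can be exercised without touching the network by injecting a stub risk_provider — a minimal sketch (fake_risk is a made-up stand-in for get_risk_score, which would otherwise call the Open-Meteo archive once per grid node):

    from openmeteo_client import compute_reroute, risk_to_index

    # Stub provider: risk grows linearly with latitude, so the planner
    # should hug the southern edge of the grid.
    def fake_risk(lat, lon):
        return max(0.0, (lat - 38.9) * 10.0)

    route = compute_reroute(38.9, -77.0, 38.95, -77.02,
                            risk_provider=fake_risk, n_lat=10, n_lon=10)
    print(route["reroute_needed"], route["grid_shape"], len(route["path"]))

    # risk_to_index bins a score into 1..num_bins: with the defaults
    # (max_risk=10, num_bins=10), a score of 3.4 lands in int(3.4 // 1.0) + 1 == 4
    assert risk_to_index(3.4) == 4

Note the call budget: with the default 20x20 grid, a real risk_provider issues up to 400 archive requests per reroute, so max_calls (or a coarser grid) matters in practice.
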
diff --git a/roadcast/openmeteo_inference.py b/roadcast/openmeteo_inference.py
new file mode 100644
index 0000000..541ae17
--- /dev/null
+++ b/roadcast/openmeteo_inference.py
@@ -0,0 +1,339 @@
+"""
+Fetch weather data for a coordinate/time and run the trained MLP to predict the k-means cluster label.
+(Despite the module's Open-Meteo name, live weather currently comes from the OpenWeather One Call API,
+which requires an API key.)
+
+Usage examples:
+    # with a training CSV provided to compute preprocessing stats:
+    python openmeteo_inference.py --lat 38.9 --lon -77.0 --datetime "2025-09-27T12:00:00" --train-csv data.csv --model model.pth --centers kmeans_centers_all.npz --api-key $OPENWEATHER_KEY
+
+    # with precomputed preprocess meta (saved from training):
+    python openmeteo_inference.py --lat 38.9 --lon -77.0 --datetime "2025-09-27T12:00:00" --preprocess-meta preprocess_meta.npz --model model.pth --centers kmeans_centers_all.npz --api-key $OPENWEATHER_KEY
+
+Notes:
+- The script uses the same feature-engineering helpers in `data.py` so the model sees identical inputs.
+- You must either provide `--train-csv` (to compute feature columns & means/stds) or a `--preprocess-meta` file saved previously.
+- Provide the OpenWeather API key via --api-key or the OPENWEATHER_KEY environment variable.
+"""
+
+import os
+import argparse
+import json
+from datetime import datetime
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+
+# reuse helpers from the repo
+from data import _add_date_features, _add_latlon_bins, _add_hashed_street, CSVDataset
+from inference import load_model
+
+# module-level caches to avoid reloading heavy artifacts per request
+_CACHED_MODEL = None
+_CACHED_IDX_TO_CLASS = None
+_CACHED_CENTERS = None
+_CACHED_PREPROCESS_META = None
+
+OW_BASE = 'https://api.openweathermap.org/data/2.5/onecall'
+
+
+def fetch_openmeteo(lat, lon, api_key, dt_iso=None):
+    """Fetch weather from the OpenWeather One Call API for the given lat/lon.
+
+    If dt_iso is provided, fetch current+hourly data and pick the closest hourly timestamp.
+    """
+    try:
+        import requests
+    except Exception:
+        raise RuntimeError('requests library is required to fetch OpenWeather data')
+    params = {
+        'lat': float(lat),
+        'lon': float(lon),
+        'appid': api_key,
+        'units': 'metric',
+        'exclude': 'minutely,alerts'
+    }
+    r = requests.get(OW_BASE, params=params, timeout=10)
+    r.raise_for_status()
+    payload = r.json()
+    # if dt_iso provided, find the nearest hourly data point
+    if dt_iso:
+        try:
+            target = pd.to_datetime(dt_iso)
+        except Exception:
+            target = None
+        best = None
+        if 'hourly' in payload and target is not None and payload['hourly']:
+            hours = payload['hourly']
+            best = min(hours, key=lambda h: abs(pd.to_datetime(h['dt'], unit='s') - target))
+        if best is not None:
+            # convert keys to a flat dict with prefix 'ow_'
+            d = {
+                'ow_temp': best.get('temp'),
+                'ow_feels_like': best.get('feels_like'),
+                'ow_pressure': best.get('pressure'),
+                'ow_humidity': best.get('humidity'),
+                'ow_wind_speed': best.get('wind_speed'),
+                'ow_clouds': best.get('clouds'),
+                'ow_pop': best.get('pop'),
+            }
+            return d
+    # fallback: use current conditions
+    cur = payload.get('current', {})
+    d = {
+        'ow_temp': cur.get('temp'),
+        'ow_feels_like': cur.get('feels_like'),
+        'ow_pressure': cur.get('pressure'),
+        'ow_humidity': cur.get('humidity'),
+        'ow_wind_speed': cur.get('wind_speed'),
+        'ow_clouds': cur.get('clouds'),
+        'ow_pop': None,
+    }
+    return d
+
+
+def fetch_roadrisk(roadrisk_url, api_key=None):
+    """Fetch the RoadRisk endpoint (expects JSON). If api_key is provided and the URL
+    carries no key, it is attached as a query parameter.
+
+    Top-level numeric fields are flattened into `rr_*` keys for the feature row.
+    """
+    try:
+        import requests
+    except Exception:
+        raise RuntimeError('requests library is required to fetch RoadRisk data')
+    # if api_key provided and the url does not contain appid, append it
+    url = roadrisk_url
+    if api_key and 'appid=' not in roadrisk_url:
+        sep = '&' if '?' in roadrisk_url else '?'
+        url = f"{roadrisk_url}{sep}appid={api_key}"
+
+    r = requests.get(url, timeout=10)
+    r.raise_for_status()
+    payload = r.json()
+    # flatten numeric top-level fields
+    out = {}
+    if isinstance(payload, dict):
+        for k, v in payload.items():
+            if isinstance(v, (int, float)):
+                out[f'rr_{k}'] = v
+            # if nested objects contain simple numeric fields, pull them too (one level deep)
+            elif isinstance(v, dict):
+                for kk, vv in v.items():
+                    if isinstance(vv, (int, float)):
+                        out[f'rr_{k}_{kk}'] = vv
+    return out
+
+
+def build_row(lat, lon, dt_iso=None, street=None, extra_weather=None):
+    """Construct a single-row DataFrame with the columns expected by the training pipeline.
+
+    It intentionally uses column names the original `data.py` looked for
+    (REPORTDATE, LATITUDE, LONGITUDE, ADDRESS, etc.).
+    """
+    row = {}
+    # date column matching common names
+    row['REPORTDATE'] = dt_iso if dt_iso else datetime.utcnow().isoformat()
+    row['LATITUDE'] = lat
+    row['LONGITUDE'] = lon
+    row['ADDRESS'] = street if street else ''
+    # include the injury/fatality placeholders that the label generator expects
+    row['INJURIES'] = 0
+    row['FATALITIES'] = 0
+    # include weather features returned by OpenWeather (prefixed 'ow_')
+    if extra_weather:
+        for k, v in extra_weather.items():
+            row[k] = v
+    return pd.DataFrame([row])
+
+
+def prepare_features(df_row, train_csv=None, preprocess_meta=None, feature_engineer=True, lat_lon_bins=20):
+    """Given a one-row DataFrame, apply the same feature engineering and standardization as training.
+
+    If preprocess_meta (an .npz path) is provided, use it; otherwise train_csv must be provided to compute stats.
+    Returns a torch.FloatTensor of shape (1, input_dim) and the feature_columns list.
+    """
+    # apply feature-engineering helpers
+    if feature_engineer:
+        try:
+            _add_date_features(df_row)
+        except Exception:
+            pass
+        try:
+            _add_latlon_bins(df_row, bins=lat_lon_bins)
+        except Exception:
+            pass
+        try:
+            _add_hashed_street(df_row)
+        except Exception:
+            pass
+
+    # if meta provided, load feature_columns, means, stds
+    if preprocess_meta and os.path.exists(preprocess_meta):
+        meta = np.load(preprocess_meta, allow_pickle=True)
+        feature_columns = meta['feature_columns'].tolist()
+        means = meta['means']
+        stds = meta['stds']
+    else:
+        if not train_csv:
+            raise ValueError('Either preprocess_meta or train_csv must be provided to derive feature stats')
+        # instantiate a CSVDataset on train_csv (feature_engineer True) to reuse its preprocessing
+        ds = CSVDataset(train_csv, feature_columns=None, label_column='label', generate_labels=True, n_buckets=10, label_method='kmeans', label_store=None, feature_engineer=feature_engineer, lat_lon_bins=lat_lon_bins, nrows=None)
+        feature_columns = ds.feature_columns
+        means = ds.feature_means
+        stds = ds.feature_stds
+        # save meta for reuse
+        np.savez_compressed('preprocess_meta.npz', feature_columns=np.array(feature_columns, dtype=object), means=means, stds=stds)
+        print('Saved preprocess_meta.npz')
+
+    # ensure all feature columns exist in df_row
+    for c in feature_columns:
+        if c not in df_row.columns:
+            df_row[c] = 0
+
+    # coerce to numeric and fill gaps with the training means
+    features_df = df_row[feature_columns].apply(lambda c: pd.to_numeric(c, errors='coerce'))
+    features_df = features_df.fillna(pd.Series(means, index=feature_columns)).fillna(0.0)
+    # standardize
+    features_np = (features_df.values - means) / (stds + 1e-6)
+    return torch.tensor(features_np, dtype=torch.float32), feature_columns
+
+
+def predict_from_openmeteo(lat, lon, dt_iso=None, street=None, api_key=None, train_csv=None, preprocess_meta=None, model_path='model.pth', centers_path='kmeans_centers_all.npz', roadrisk_url=None):
+    api_key = api_key or os.environ.get('OPENWEATHER_KEY')
+    if api_key is None:
+        raise ValueError('OpenWeather API key required via --api-key or the OPENWEATHER_KEY env var')
+
+    # gather weather / road-risk features
+    weather = {}
+    if roadrisk_url:
+        try:
+            rr = fetch_roadrisk(roadrisk_url, api_key=api_key)
+            weather.update(rr)
+        except Exception as e:
+            print('Warning: failed to fetch roadrisk URL:', e)
+    else:
+        try:
+            ow = fetch_openmeteo(lat, lon, api_key, dt_iso=dt_iso)
+            weather.update(ow)
+        except Exception as e:
+            print('Warning: failed to fetch weather:', e)
+
+    df_row = build_row(lat, lon, dt_iso=dt_iso, street=street, extra_weather=weather)
+    x_tensor, feature_columns = prepare_features(df_row, train_csv=train_csv, preprocess_meta=preprocess_meta)
+
+    # load model (infer num_classes from the centers file if possible)
+    global _CACHED_MODEL, _CACHED_IDX_TO_CLASS, _CACHED_CENTERS, _CACHED_PREPROCESS_META
+
+    # ensure preprocess_meta is available (prefer the supplied path, otherwise fall back to the saved file)
+    if preprocess_meta is None:
+        candidate = os.path.join(os.getcwd(), 'preprocess_meta.npz')
+        if os.path.exists(candidate):
+            preprocess_meta = candidate
+
+    # load centers (cached across requests)
+    if _CACHED_CENTERS is None and centers_path and os.path.exists(centers_path):
+        try:
+            npz = np.load(centers_path)
+            _CACHED_CENTERS = npz['centers']
+        except Exception:
+            _CACHED_CENTERS = None
+
+    num_classes = _CACHED_CENTERS.shape[0] if _CACHED_CENTERS is not None else 10
+
+    # load the model once and cache it
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    if _CACHED_MODEL is None:
+        _CACHED_MODEL, _CACHED_IDX_TO_CLASS = load_model(model_path, device=None, in_channels=3, num_classes=num_classes)
+        _CACHED_MODEL.to(device)
+    model = _CACHED_MODEL
+    idx_to_class = _CACHED_IDX_TO_CLASS
+    x_tensor = x_tensor.to(device)
+    with torch.no_grad():
+        logits = model(x_tensor)
+        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
+        pred_idx = int(probs.argmax())
+        confidence = float(probs.max())
+
+    # optionally provide cluster centroid info
+    centroid = _CACHED_CENTERS[pred_idx] if _CACHED_CENTERS is not None else None
+
+    return {
+        'pred_cluster': int(pred_idx),
+        'confidence': confidence,
+        'probabilities': probs.tolist(),
+        'centroid': centroid.tolist() if centroid is not None else None,
+        'feature_columns': feature_columns,
+        'used_preprocess_meta': preprocess_meta
+    }
+
+
+def init_inference(model_path='model.pth', centers_path='kmeans_centers_all.npz', preprocess_meta=None):
+    """Eagerly load model, centers, and preprocess_meta into module-level caches.
+
+    This is intended to be called at app startup to surface load errors early and avoid
+    per-request disk IO. The function is best-effort and prints warnings if artifacts
+    are missing.
+    """
+    global _CACHED_MODEL, _CACHED_IDX_TO_CLASS, _CACHED_CENTERS, _CACHED_PREPROCESS_META
+
+    # prefer an existing saved preprocess_meta if not explicitly provided
+    if preprocess_meta is None:
+        candidate = os.path.join(os.getcwd(), 'preprocess_meta.npz')
+        if os.path.exists(candidate):
+            preprocess_meta = candidate
+
+    _CACHED_PREPROCESS_META = preprocess_meta
+
+    # load centers
+    if _CACHED_CENTERS is None:
+        if centers_path and os.path.exists(centers_path):
+            try:
+                npz = np.load(centers_path)
+                _CACHED_CENTERS = npz['centers']
+                print(f'Loaded centers from {centers_path}')
+            except Exception as e:
+                print('Warning: failed to load centers:', e)
+                _CACHED_CENTERS = None
+        else:
+            print('No centers file found at', centers_path)
+            _CACHED_CENTERS = None
+
+    num_classes = _CACHED_CENTERS.shape[0] if _CACHED_CENTERS is not None else 10
+
+    # load model
+    if _CACHED_MODEL is None:
+        try:
+            _CACHED_MODEL, _CACHED_IDX_TO_CLASS = load_model(model_path, device=None, in_channels=3, num_classes=num_classes)
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            _CACHED_MODEL.to(device)
+            print(f'Loaded model from {model_path}')
+        except Exception as e:
+            print('Warning: failed to load model:', e)
+            _CACHED_MODEL = None
+
+    return {
+        'model_loaded': _CACHED_MODEL is not None,
+        'centers_loaded': _CACHED_CENTERS is not None,
+        'preprocess_meta': _CACHED_PREPROCESS_META
+    }
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--lat', type=float, required=True)
+    parser.add_argument('--lon', type=float, required=True)
+    parser.add_argument('--datetime', default=None, help='ISO datetime string to query hourly weather (optional)')
+    parser.add_argument('--street', default='')
+    parser.add_argument('--api-key', default=None, help='OpenWeather API key, or use the OPENWEATHER_KEY env var')
+    parser.add_argument('--train-csv', default=None, help='Path to training CSV to compute preprocessing stats (optional if --preprocess-meta provided)')
+    parser.add_argument('--preprocess-meta', default=None, help='Path to precomputed preprocess_meta.npz (optional)')
+    parser.add_argument('--model', default='model.pth')
+    parser.add_argument('--centers', default='kmeans_centers_all.npz')
+    parser.add_argument('--roadrisk-url', default=None, help='Optional custom RoadRisk API URL (if provided, it is queried instead of One Call)')
+    args = parser.parse_args()
+
+    out = predict_from_openmeteo(args.lat, args.lon, dt_iso=args.datetime, street=args.street, api_key=args.api_key, train_csv=args.train_csv, preprocess_meta=args.preprocess_meta, model_path=args.model, centers_path=args.centers, roadrisk_url=args.roadrisk_url)
+    print(json.dumps(out, indent=2))
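
A sketch of warming the caches at startup and then serving one prediction in-process (assumes model.pth, kmeans_centers_all.npz, and data.csv exist in the working directory and that the repo's data.py / inference.py helpers are importable; predict_from_openmeteo raises without an OpenWeather key):

    import os
    from openmeteo_inference import init_inference, predict_from_openmeteo

    # best-effort eager load; reports which artifacts were found
    status = init_inference(model_path='model.pth',
                            centers_path='kmeans_centers_all.npz')
    print(status)

    out = predict_from_openmeteo(38.9, -77.0,
                                 dt_iso='2025-09-27T12:00:00',
                                 api_key=os.environ.get('OPENWEATHER_KEY'),
                                 train_csv='data.csv')
    print(out['pred_cluster'], out['confidence'])
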