added the model
roadcast/openweather_inference.py (new file, +339 lines)
@@ -0,0 +1,339 @@
"""
Fetch OpenWeather data for a coordinate/time and run the trained MLP to
predict the k-means cluster label.

Usage examples:

    # with a training CSV provided to compute preprocessing stats:
    python openweather_inference.py --lat 38.9 --lon -77.0 --datetime "2025-09-27T12:00:00" --train-csv data.csv --model model.pth --centers kmeans_centers_all.npz --api-key $OPENWEATHER_KEY

    # with precomputed preprocessing metadata (saved from training):
    python openweather_inference.py --lat 38.9 --lon -77.0 --datetime "2025-09-27T12:00:00" --preprocess-meta preprocess_meta.npz --model model.pth --centers kmeans_centers_all.npz --api-key $OPENWEATHER_KEY

Notes:
- The script reuses the feature-engineering helpers from `data.py`, so the model sees inputs identical to training.
- You must provide either `--train-csv` (to compute feature columns and means/stds) or a previously saved `--preprocess-meta` file.
- Provide the OpenWeather API key via --api-key or the OPENWEATHER_KEY environment variable.
"""

import os
import argparse
import json
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# reuse helpers from your repo
from data import _add_date_features, _add_latlon_bins, _add_hashed_street, CSVDataset
from inference import load_model

# module-level caches to avoid reloading heavy artifacts per request
_CACHED_MODEL = None
_CACHED_IDX_TO_CLASS = None
_CACHED_CENTERS = None
_CACHED_PREPROCESS_META = None


OW_BASE = 'https://api.openweathermap.org/data/2.5/onecall'

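# NOTE (not from the original commit): OpenWeather has deprecated the v2.5
# One Call endpoint in favor of v3.0. If requests against OW_BASE start
# failing for newer accounts, switching to
# 'https://api.openweathermap.org/data/3.0/onecall' (which requires a One
# Call 3.0 subscription) should be a drop-in change, since the response
# fields used below ('current'/'hourly') are the same.

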
def fetch_openweather(lat, lon, api_key, dt_iso=None):
    """Fetch weather from the OpenWeather One Call API for the given lat/lon.

    If dt_iso is provided, fetch current+hourly and pick the hourly entry
    closest to that timestamp; otherwise fall back to current conditions.
    """
    try:
        import requests
    except ImportError:
        raise RuntimeError('the requests library is required to fetch OpenWeather data')
    params = {
        'lat': float(lat),
        'lon': float(lon),
        'appid': api_key,
        'units': 'metric',
        'exclude': 'minutely,alerts',
    }
    r = requests.get(OW_BASE, params=params, timeout=10)
    r.raise_for_status()
    payload = r.json()
    # if dt_iso was provided, find the nearest hourly data point
    if dt_iso:
        try:
            target = pd.to_datetime(dt_iso)
        except (ValueError, TypeError):
            target = None
        best = None
        if 'hourly' in payload and target is not None:
            hours = payload['hourly']
            best = min(hours, key=lambda h: abs(pd.to_datetime(h['dt'], unit='s') - target))
        # flatten the chosen hourly entry into 'ow_'-prefixed keys; guard
        # against best being None (no hourly data or unparseable timestamp),
        # in which case we fall through to the 'current' block below
        if best is not None:
            return {
                'ow_temp': best.get('temp'),
                'ow_feels_like': best.get('feels_like'),
                'ow_pressure': best.get('pressure'),
                'ow_humidity': best.get('humidity'),
                'ow_wind_speed': best.get('wind_speed'),
                'ow_clouds': best.get('clouds'),
                'ow_pop': best.get('pop'),
            }
    # fallback: use current conditions
    cur = payload.get('current', {})
    return {
        'ow_temp': cur.get('temp'),
        'ow_feels_like': cur.get('feels_like'),
        'ow_pressure': cur.get('pressure'),
        'ow_humidity': cur.get('humidity'),
        'ow_wind_speed': cur.get('wind_speed'),
        'ow_clouds': cur.get('clouds'),
        'ow_pop': None,
    }


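# Illustrative only: with hourly data available, fetch_openweather returns a
# flat dict shaped like the sketch below (values are made up; 'ow_pop' comes
# only from hourly entries and is None on the 'current' fallback):
#
#   {'ow_temp': 21.3, 'ow_feels_like': 20.9, 'ow_pressure': 1014,
#    'ow_humidity': 62, 'ow_wind_speed': 3.1, 'ow_clouds': 40, 'ow_pop': 0.2}

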
def fetch_roadrisk(roadrisk_url, api_key=None):
    """Fetch the RoadRisk endpoint (expects JSON). If `api_key` is provided and
    the URL carries no `appid`, it is appended as a query parameter.

    We flatten top-level numeric fields into `rr_*` keys for the feature row.
    """
    try:
        import requests
    except ImportError:
        raise RuntimeError('the requests library is required to fetch RoadRisk data')
    # if api_key is provided and the url does not contain appid, append it
    url = roadrisk_url
    if api_key and 'appid=' not in roadrisk_url:
        sep = '&' if '?' in roadrisk_url else '?'
        url = f"{roadrisk_url}{sep}appid={api_key}"

    r = requests.get(url, timeout=10)
    r.raise_for_status()
    payload = r.json()
    # flatten numeric top-level fields
    out = {}
    if isinstance(payload, dict):
        for k, v in payload.items():
            if isinstance(v, (int, float)):
                out[f'rr_{k}'] = v
            # if nested objects contain simple numeric fields, pull them too (one level deep)
            elif isinstance(v, dict):
                for kk, vv in v.items():
                    if isinstance(vv, (int, float)):
                        out[f'rr_{k}_{kk}'] = vv
    return out


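# Illustrative only: given a hypothetical RoadRisk payload such as
#   {"risk_score": 0.7, "weather": {"temp": 2.5, "precipitation": 1.2}, "road": "icy"}
# fetch_roadrisk would produce
#   {'rr_risk_score': 0.7, 'rr_weather_temp': 2.5, 'rr_weather_precipitation': 1.2}
# (string fields like "road" are dropped; only numeric values survive).

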
def build_row(lat, lon, dt_iso=None, street=None, extra_weather=None):
    """Construct a single-row DataFrame with the columns expected by the training pipeline.

    It intentionally uses the column names the original `data.py` looked for
    (REPORTDATE, LATITUDE, LONGITUDE, ADDRESS, etc.).
    """
    row = {}
    # date column matching common names
    row['REPORTDATE'] = dt_iso if dt_iso else datetime.utcnow().isoformat()
    row['LATITUDE'] = lat
    row['LONGITUDE'] = lon
    row['ADDRESS'] = street if street else ''
    # injury/fatality placeholders that the label generator expects
    row['INJURIES'] = 0
    row['FATALITIES'] = 0
    # include weather features returned by OpenWeather (prefixed 'ow_')
    if extra_weather:
        for k, v in extra_weather.items():
            row[k] = v
    return pd.DataFrame([row])


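# Illustrative only: build_row(38.9, -77.0, dt_iso='2025-09-27T12:00:00',
# extra_weather={'ow_temp': 21.3}) yields a one-row DataFrame with columns
# REPORTDATE, LATITUDE, LONGITUDE, ADDRESS, INJURIES, FATALITIES and ow_temp;
# the feature-engineering helpers invoked by prepare_features derive the rest.

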
def prepare_features(df_row, train_csv=None, preprocess_meta=None, feature_engineer=True, lat_lon_bins=20):
    """Given a one-row DataFrame, apply the same feature engineering and standardization as training.

    If preprocess_meta is provided (npz), use it. Otherwise train_csv must be provided to compute stats.
    Returns a torch.FloatTensor of shape (1, input_dim) and the feature_columns list.
    """
    # apply the feature-engineering helpers (best-effort; each may fail on sparse input)
    if feature_engineer:
        try:
            _add_date_features(df_row)
        except Exception:
            pass
        try:
            _add_latlon_bins(df_row, bins=lat_lon_bins)
        except Exception:
            pass
        try:
            _add_hashed_street(df_row)
        except Exception:
            pass

    # if meta is provided, load feature_columns, means, stds from it
    if preprocess_meta and os.path.exists(preprocess_meta):
        meta = np.load(preprocess_meta, allow_pickle=True)
        feature_columns = meta['feature_columns'].tolist()
        means = meta['means']
        stds = meta['stds']
    else:
        if not train_csv:
            raise ValueError('Either preprocess_meta or train_csv must be provided to derive feature stats')
        # instantiate a CSVDataset on train_csv to reuse its preprocessing
        ds = CSVDataset(train_csv, feature_columns=None, label_column='label', generate_labels=False,
                        n_buckets=10, label_method='kmeans', label_store=None,
                        feature_engineer=feature_engineer, lat_lon_bins=lat_lon_bins, nrows=None)
        feature_columns = ds.feature_columns
        means = ds.feature_means
        stds = ds.feature_stds
        # save meta for reuse
        np.savez_compressed('preprocess_meta.npz', feature_columns=np.array(feature_columns, dtype=object), means=means, stds=stds)
        print('Saved preprocess_meta.npz')

    # ensure all feature columns exist in df_row
    for c in feature_columns:
        if c not in df_row.columns:
            df_row[c] = 0

    # coerce to numeric and fill missing values with the training means
    features_df = df_row[feature_columns].apply(lambda c: pd.to_numeric(c, errors='coerce'))
    features_df = features_df.fillna(pd.Series(means, index=feature_columns)).fillna(0.0)
    # standardize with the training statistics
    features_np = (features_df.values - means) / (stds + 1e-6)
    return torch.tensor(features_np, dtype=torch.float32), feature_columns


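# A minimal sketch (assumed; not part of the training code in this commit) of
# how the preprocess_meta.npz consumed above could be written at training time
# from a fitted CSVDataset `ds`. prepare_features saves an identical file as a
# side effect whenever it has to fall back to --train-csv:
#
#   np.savez_compressed('preprocess_meta.npz',
#                       feature_columns=np.array(ds.feature_columns, dtype=object),
#                       means=ds.feature_means, stds=ds.feature_stds)

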
def predict_from_openweather(lat, lon, dt_iso=None, street=None, api_key=None, train_csv=None,
                             preprocess_meta=None, model_path='model.pth',
                             centers_path='kmeans_centers_all.npz', roadrisk_url=None):
    api_key = api_key or os.environ.get('OPENWEATHER_KEY')
    if api_key is None:
        raise ValueError('OpenWeather API key required via --api-key or the OPENWEATHER_KEY env var')

    # gather weather/road-risk features
    weather = {}
    if roadrisk_url:
        try:
            weather.update(fetch_roadrisk(roadrisk_url, api_key=api_key))
        except Exception as e:
            print('Warning: failed to fetch roadrisk URL:', e)
    else:
        try:
            weather.update(fetch_openweather(lat, lon, api_key, dt_iso=dt_iso))
        except Exception as e:
            print('Warning: failed to fetch openweather:', e)

    df_row = build_row(lat, lon, dt_iso=dt_iso, street=street, extra_weather=weather)
    x_tensor, feature_columns = prepare_features(df_row, train_csv=train_csv, preprocess_meta=preprocess_meta)

    global _CACHED_MODEL, _CACHED_IDX_TO_CLASS, _CACHED_CENTERS, _CACHED_PREPROCESS_META

    # ensure preprocess_meta is recorded (prefer the supplied path, otherwise
    # fall back to a previously saved file in the working directory)
    if preprocess_meta is None:
        candidate = os.path.join(os.getcwd(), 'preprocess_meta.npz')
        if os.path.exists(candidate):
            preprocess_meta = candidate

    # load centers once and cache them across requests
    if _CACHED_CENTERS is None:
        if centers_path and os.path.exists(centers_path):
            try:
                npz = np.load(centers_path)
                _CACHED_CENTERS = npz['centers']
            except Exception:
                _CACHED_CENTERS = None
        else:
            _CACHED_CENTERS = None

    # infer num_classes from the centers file if possible
    num_classes = _CACHED_CENTERS.shape[0] if _CACHED_CENTERS is not None else 10

    # load the model once and cache it
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if _CACHED_MODEL is None:
        _CACHED_MODEL, _CACHED_IDX_TO_CLASS = load_model(model_path, device=None, in_channels=3, num_classes=num_classes)
        _CACHED_MODEL.to(device)
    model = _CACHED_MODEL
    idx_to_class = _CACHED_IDX_TO_CLASS  # kept for parity with the cache; unused below

    x_tensor = x_tensor.to(device)
    with torch.no_grad():
        logits = model(x_tensor)
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
    pred_idx = int(probs.argmax())
    confidence = float(probs.max())

    # optionally provide cluster centroid info
    centroid = _CACHED_CENTERS[pred_idx] if _CACHED_CENTERS is not None else None

    return {
        'pred_cluster': pred_idx,
        'confidence': confidence,
        'probabilities': probs.tolist(),
        'centroid': centroid.tolist() if centroid is not None else None,
        'feature_columns': feature_columns,
        'used_preprocess_meta': preprocess_meta,
    }


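# Illustrative only: calling predict_from_openweather programmatically
# (assumes the artifacts exist in the working directory and OPENWEATHER_KEY is set):
#
#   result = predict_from_openweather(38.9, -77.0, dt_iso='2025-09-27T12:00:00',
#                                     preprocess_meta='preprocess_meta.npz')
#   result['pred_cluster']   # e.g. 3
#   result['confidence']     # e.g. 0.82
#   result['probabilities']  # one softmax probability per cluster

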
def init_inference(model_path='model.pth', centers_path='kmeans_centers_all.npz', preprocess_meta=None):
    """Eagerly load the model, centers, and preprocess_meta into the module-level caches.

    This is intended to be called at app startup to surface load errors early and
    avoid per-request disk IO. The function is best-effort and prints warnings if
    artifacts are missing.
    """
    global _CACHED_MODEL, _CACHED_IDX_TO_CLASS, _CACHED_CENTERS, _CACHED_PREPROCESS_META

    # prefer an existing saved preprocess_meta if none is explicitly provided
    if preprocess_meta is None:
        candidate = os.path.join(os.getcwd(), 'preprocess_meta.npz')
        if os.path.exists(candidate):
            preprocess_meta = candidate

    _CACHED_PREPROCESS_META = preprocess_meta

    # load centers
    if _CACHED_CENTERS is None:
        if centers_path and os.path.exists(centers_path):
            try:
                npz = np.load(centers_path)
                _CACHED_CENTERS = npz['centers']
                print(f'Loaded centers from {centers_path}')
            except Exception as e:
                print('Warning: failed to load centers:', e)
                _CACHED_CENTERS = None
        else:
            print('No centers file found at', centers_path)
            _CACHED_CENTERS = None

    num_classes = _CACHED_CENTERS.shape[0] if _CACHED_CENTERS is not None else 10

    # load model
    if _CACHED_MODEL is None:
        try:
            _CACHED_MODEL, _CACHED_IDX_TO_CLASS = load_model(model_path, device=None, in_channels=3, num_classes=num_classes)
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            _CACHED_MODEL.to(device)
            print(f'Loaded model from {model_path}')
        except Exception as e:
            print('Warning: failed to load model:', e)
            _CACHED_MODEL = None

    return {
        'model_loaded': _CACHED_MODEL is not None,
        'centers_loaded': _CACHED_CENTERS is not None,
        'preprocess_meta': _CACHED_PREPROCESS_META,
    }


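# A minimal sketch (assumed; no web app ships with this module) of calling
# init_inference at service startup so load errors surface before the first
# request, e.g. inside a Flask or FastAPI app factory:
#
#   status = init_inference(model_path='model.pth',
#                           centers_path='kmeans_centers_all.npz')
#   if not status['model_loaded']:
#       raise RuntimeError('model failed to load; see warnings above')

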
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lat', type=float, required=True)
    parser.add_argument('--lon', type=float, required=True)
    parser.add_argument('--datetime', default=None, help='ISO datetime string to query hourly weather (optional)')
    parser.add_argument('--street', default='')
    parser.add_argument('--api-key', default=None, help='OpenWeather API key (or use the OPENWEATHER_KEY env var)')
    parser.add_argument('--train-csv', default=None, help='Path to training CSV to compute preprocessing stats (optional if --preprocess-meta provided)')
    parser.add_argument('--preprocess-meta', default=None, help='Path to precomputed preprocess_meta.npz (optional)')
    parser.add_argument('--model', default='model.pth')
    parser.add_argument('--centers', default='kmeans_centers_all.npz')
    parser.add_argument('--roadrisk-url', default=None, help='Optional custom RoadRisk API URL (if provided, it is queried instead of One Call)')
    args = parser.parse_args()

    out = predict_from_openweather(args.lat, args.lon, dt_iso=args.datetime, street=args.street,
                                   api_key=args.api_key, train_csv=args.train_csv,
                                   preprocess_meta=args.preprocess_meta, model_path=args.model,
                                   centers_path=args.centers, roadrisk_url=args.roadrisk_url)
    print(json.dumps(out, indent=2))