Files
VTHacks13/roadcast/run_batch_inference.py
samarthjain2023 0df2b0019b added the model
2025-09-27 12:14:26 -04:00

114 lines
4.8 KiB
Python

import os
import argparse
import pandas as pd
import numpy as np
import time
import openweather_inference as owi
def find_column(df_cols, candidates):
    """Return the first entry of *candidates* that names a column in *df_cols*.

    Matching ignores case, but the value returned is the column name spelled
    exactly as it appears in *df_cols*. If several columns differ only by
    case, the last one in *df_cols* is the one reported. Returns ``None``
    when no candidate matches.
    """
    by_lower = {}
    for col in df_cols:
        by_lower[col.lower()] = col
    return next(
        (by_lower[cand.lower()] for cand in candidates if cand.lower() in by_lower),
        None,
    )
def _build_parser():
    """Return the command-line parser for the batch inference script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('csv', help='Path to data CSV (e.g., data.csv)')
    parser.add_argument('--out', default='inference_results.csv')
    parser.add_argument('--lat-col', default=None)
    parser.add_argument('--lon-col', default=None)
    parser.add_argument('--date-col', default=None)
    parser.add_argument('--model', default='model.pth')
    parser.add_argument('--centers', default='kmeans_centers_all.npz')
    parser.add_argument('--preprocess-meta', default=None)
    parser.add_argument('--api-key', default=None)
    parser.add_argument('--live', action='store_true', help='If set, call external RoadRisk/OpenWeather per row')
    parser.add_argument('--roadrisk-url', default=None, help='Optional per-request RoadRisk URL to use when --live')
    parser.add_argument('--subset', type=int, default=0, help='Process only first N rows for testing')
    return parser


def _missing(value):
    """Return True for ``None`` or a pandas/NumPy missing scalar (NaN, NaT)."""
    return value is None or (pd.api.types.is_scalar(value) and bool(pd.isna(value)))


def _predict_local(lat, lon, dt, args, device):
    """Score one row with the cached model, without any network calls.

    Builds the feature row via ``owi.build_row`` / ``owi.prepare_features``,
    then runs the model cached in ``owi._CACHED_MODEL``.

    Returns:
        dict with ``pred_cluster``, ``confidence``, ``probabilities``,
        ``centroid`` (``None`` if no centers are cached) and ``feature_columns``.

    Raises:
        RuntimeError: if the model cannot be loaded.
    """
    import torch
    import torch.nn.functional as F

    df_row = owi.build_row(lat, lon, dt_iso=dt, street=None, extra_weather=None)
    x_tensor, feature_columns = owi.prepare_features(df_row, train_csv=None, preprocess_meta=args.preprocess_meta)
    # Retry loading if the eager init in main() did not populate the cache.
    if owi._CACHED_MODEL is None:
        owi.init_inference(model_path=args.model, centers_path=args.centers, preprocess_meta=args.preprocess_meta)
    model = owi._CACHED_MODEL
    if model is None:
        raise RuntimeError(f'model could not be loaded from {args.model}')
    centers = owi._CACHED_CENTERS
    model.to(device)
    # Inference mode (dropout off, batch-norm frozen); harmless if
    # init_inference already did this.
    model.eval()
    with torch.no_grad():
        logits = model(x_tensor.to(device))
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
    pred_idx = int(probs.argmax())
    return {
        'pred_cluster': pred_idx,
        'confidence': float(probs.max()),
        'probabilities': probs.tolist(),
        'centroid': centers[pred_idx].tolist() if centers is not None else None,
        'feature_columns': feature_columns,
    }


def main(argv=None):
    """Run the cluster model over every row of a CSV and write a results CSV.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]`` (argparse
            behaviour), which lets the script also be driven programmatically.

    Each output row carries ``orig_index``, ``lat``, ``lon``, ``datetime`` and
    either ``pred_cluster``/``confidence`` or an ``error`` message. A failing
    row is recorded and never aborts the batch.
    """
    args = _build_parser().parse_args(argv)

    df = pd.read_csv(args.csv, low_memory=False)
    if args.subset and args.subset > 0:
        # head() clamps to the rows actually present, so nrows stays accurate.
        df = df.head(args.subset)
    nrows = len(df)

    # find sensible columns
    lat_col = args.lat_col or find_column(df.columns, ['latitude', 'lat', 'mpdlatitude'])
    lon_col = args.lon_col or find_column(df.columns, ['longitude', 'lon', 'mpdlongitude'])
    date_col = args.date_col or find_column(df.columns, ['report_dat', 'reportdate', 'fromdate', 'lastupdatedate', 'date', 'occur_date'])
    if lat_col is None or lon_col is None:
        raise SystemExit('Could not find latitude/longitude columns automatically. Pass --lat-col and --lon-col.')
    print(f'Using lat column: {lat_col}, lon column: {lon_col}, date column: {date_col}')

    # eager init caches
    status = owi.init_inference(model_path=args.model, centers_path=args.centers, preprocess_meta=args.preprocess_meta)
    print('init status:', status)

    # Resolve the torch device once; only the local path needs torch here.
    device = None
    if not args.live:
        import torch
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    results = []
    t0 = time.perf_counter()
    for idx, row in df.iterrows():
        lat = row.get(lat_col)
        lon = row.get(lon_col)
        dt = row.get(date_col) if date_col else None
        if _missing(dt):
            # An empty date cell is treated like a missing date column.
            dt = None
        res = {
            'orig_index': idx,
            'lat': lat,
            'lon': lon,
            'datetime': None if dt is None else str(dt),
        }
        try:
            if _missing(lat) or _missing(lon):
                raise ValueError('missing latitude/longitude')
            if args.live:
                # call the full pipeline which may hit remote API
                out = owi.predict_from_openweather(lat, lon, dt_iso=dt, street=None, api_key=args.api_key, train_csv=None, preprocess_meta=args.preprocess_meta, model_path=args.model, centers_path=args.centers, roadrisk_url=args.roadrisk_url)
            else:
                out = _predict_local(lat, lon, dt, args, device)
            if 'error' in out:
                res['error'] = out['error']
            else:
                # Converted inside the try so a malformed response becomes a
                # per-row error instead of crashing the whole batch; both
                # values are computed before res is touched.
                res.update(pred_cluster=int(out['pred_cluster']), confidence=float(out['confidence']))
        except Exception as e:  # best-effort per row: record and continue
            res['error'] = str(e) or type(e).__name__
        results.append(res)
        if len(results) % 50 == 0:
            print(f'Processed {len(results)}/{nrows} rows...')

    elapsed = time.perf_counter() - t0
    print(f'Finished {len(results)} rows in {elapsed:.2f}s')
    out_df = pd.DataFrame(results)
    out_df.to_csv(args.out, index=False)
    print('Wrote', args.out)
if __name__ == "__main__":
    # Run the batch job only when executed as a script, never on import.
    main()