added the model
This commit is contained in:
113
roadcast/run_batch_inference.py
Normal file
113
roadcast/run_batch_inference.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import os
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
import openweather_inference as owi
|
||||
|
||||
|
||||
def find_column(df_cols, candidates):
|
||||
cmap = {c.lower(): c for c in df_cols}
|
||||
for cand in candidates:
|
||||
if cand.lower() in cmap:
|
||||
return cmap[cand.lower()]
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('csv', help='Path to data CSV (e.g., data.csv)')
|
||||
parser.add_argument('--out', default='inference_results.csv')
|
||||
parser.add_argument('--lat-col', default=None)
|
||||
parser.add_argument('--lon-col', default=None)
|
||||
parser.add_argument('--date-col', default=None)
|
||||
parser.add_argument('--model', default='model.pth')
|
||||
parser.add_argument('--centers', default='kmeans_centers_all.npz')
|
||||
parser.add_argument('--preprocess-meta', default=None)
|
||||
parser.add_argument('--api-key', default=None)
|
||||
parser.add_argument('--live', action='store_true', help='If set, call external RoadRisk/OpenWeather per row')
|
||||
parser.add_argument('--roadrisk-url', default=None, help='Optional per-request RoadRisk URL to use when --live')
|
||||
parser.add_argument('--subset', type=int, default=0, help='Process only first N rows for testing')
|
||||
args = parser.parse_args()
|
||||
|
||||
df = pd.read_csv(args.csv, low_memory=False)
|
||||
nrows = args.subset if args.subset and args.subset > 0 else len(df)
|
||||
df = df.iloc[:nrows].copy()
|
||||
|
||||
# find sensible columns
|
||||
lat_col = args.lat_col or find_column(df.columns, ['latitude', 'lat', 'mpdlatitude'])
|
||||
lon_col = args.lon_col or find_column(df.columns, ['longitude', 'lon', 'mpdlongitude'])
|
||||
date_col = args.date_col or find_column(df.columns, ['report_dat', 'reportdate', 'fromdate', 'lastupdatedate', 'date', 'occur_date'])
|
||||
|
||||
if lat_col is None or lon_col is None:
|
||||
raise SystemExit('Could not find latitude/longitude columns automatically. Pass --lat-col and --lon-col.')
|
||||
|
||||
print(f'Using lat column: {lat_col}, lon column: {lon_col}, date column: {date_col}')
|
||||
|
||||
# eager init caches
|
||||
status = owi.init_inference(model_path=args.model, centers_path=args.centers, preprocess_meta=args.preprocess_meta)
|
||||
print('init status:', status)
|
||||
|
||||
results = []
|
||||
t0 = time.time()
|
||||
for i, row in df.iterrows():
|
||||
lat = row.get(lat_col)
|
||||
lon = row.get(lon_col)
|
||||
dt = row.get(date_col) if date_col else None
|
||||
|
||||
try:
|
||||
if args.live:
|
||||
# call the full pipeline which may hit remote API
|
||||
out = owi.predict_from_openweather(lat, lon, dt_iso=dt, street=None, api_key=args.api_key, train_csv=None, preprocess_meta=args.preprocess_meta, model_path=args.model, centers_path=args.centers, roadrisk_url=args.roadrisk_url)
|
||||
else:
|
||||
# local-only path: build row, prepare features using preprocess_meta, and run cached model
|
||||
df_row = owi.build_row(lat, lon, dt_iso=dt, street=None, extra_weather=None)
|
||||
x_tensor, feature_columns = owi.prepare_features(df_row, train_csv=None, preprocess_meta=args.preprocess_meta)
|
||||
# ensure model cached
|
||||
if owi._CACHED_MODEL is None:
|
||||
owi.init_inference(model_path=args.model, centers_path=args.centers, preprocess_meta=args.preprocess_meta)
|
||||
model = owi._CACHED_MODEL
|
||||
centers = owi._CACHED_CENTERS
|
||||
device = 'cuda' if __import__('torch').cuda.is_available() else 'cpu'
|
||||
model.to(device)
|
||||
xt = x_tensor.to(device)
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
with torch.no_grad():
|
||||
logits = model(xt)
|
||||
probs = F.softmax(logits, dim=1).cpu().numpy()[0]
|
||||
pred_idx = int(probs.argmax())
|
||||
confidence = float(probs.max())
|
||||
out = {'pred_cluster': pred_idx, 'confidence': confidence, 'probabilities': probs.tolist(), 'centroid': centers[pred_idx].tolist() if centers is not None else None, 'feature_columns': feature_columns}
|
||||
except Exception as e:
|
||||
out = {'error': str(e)}
|
||||
|
||||
# combine row and output into flat result
|
||||
res = {
|
||||
'orig_index': i,
|
||||
'lat': lat,
|
||||
'lon': lon,
|
||||
'datetime': str(dt),
|
||||
}
|
||||
if 'error' in out:
|
||||
res.update({'error': out['error']})
|
||||
else:
|
||||
res.update({
|
||||
'pred_cluster': int(out.get('pred_cluster')),
|
||||
'confidence': float(out.get('confidence')),
|
||||
})
|
||||
results.append(res)
|
||||
|
||||
if (len(results) % 50) == 0:
|
||||
print(f'Processed {len(results)}/{nrows} rows...')
|
||||
|
||||
elapsed = time.time() - t0
|
||||
print(f'Finished {len(results)} rows in {elapsed:.2f}s')
|
||||
out_df = pd.DataFrame(results)
|
||||
out_df.to_csv(args.out, index=False)
|
||||
print('Wrote', args.out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user