import argparse
import os
import time

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

import openweather_inference as owi
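
# NOTE: the interface of openweather_inference assumed here is inferred from the
# calls below: init_inference(), build_row(), prepare_features(),
# predict_from_openweather(), and the module-level caches _CACHED_MODEL and
# _CACHED_CENTERS that init_inference() is expected to populate.
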

def find_column(df_cols, candidates):
    """Return the first column in df_cols that matches a candidate name (case-insensitive), or None."""
    cmap = {c.lower(): c for c in df_cols}
    for cand in candidates:
        if cand.lower() in cmap:
            return cmap[cand.lower()]
    return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('csv', help='Path to data CSV (e.g., data.csv)')
    parser.add_argument('--out', default='inference_results.csv', help='Output CSV path')
    parser.add_argument('--lat-col', default=None)
    parser.add_argument('--lon-col', default=None)
    parser.add_argument('--date-col', default=None)
    parser.add_argument('--model', default='model.pth')
    parser.add_argument('--centers', default='kmeans_centers_all.npz')
    parser.add_argument('--preprocess-meta', default=None)
    parser.add_argument('--api-key', default=None)
    parser.add_argument('--live', action='store_true', help='If set, call the external RoadRisk/OpenWeather API per row')
    parser.add_argument('--roadrisk-url', default=None, help='Optional per-request RoadRisk URL to use with --live')
    parser.add_argument('--subset', type=int, default=0, help='Process only the first N rows (for testing)')
    args = parser.parse_args()

    df = pd.read_csv(args.csv, low_memory=False)
    nrows = args.subset if args.subset and args.subset > 0 else len(df)
    df = df.iloc[:nrows].copy()

    # find sensible columns
    lat_col = args.lat_col or find_column(df.columns, ['latitude', 'lat', 'mpdlatitude'])
    lon_col = args.lon_col or find_column(df.columns, ['longitude', 'lon', 'mpdlongitude'])
    date_col = args.date_col or find_column(df.columns, ['report_dat', 'reportdate', 'fromdate', 'lastupdatedate', 'date', 'occur_date'])

    if lat_col is None or lon_col is None:
        raise SystemExit('Could not find latitude/longitude columns automatically. Pass --lat-col and --lon-col.')

    print(f'Using lat column: {lat_col}, lon column: {lon_col}, date column: {date_col}')

    # eager init caches
    status = owi.init_inference(model_path=args.model, centers_path=args.centers, preprocess_meta=args.preprocess_meta)
    print('init status:', status)

    results = []
    t0 = time.time()

    for i, row in df.iterrows():
        lat = row.get(lat_col)
        lon = row.get(lon_col)
        dt = row.get(date_col) if date_col else None

        try:
            if args.live:
                # call the full pipeline, which may hit the remote API
                out = owi.predict_from_openweather(lat, lon, dt_iso=dt, street=None, api_key=args.api_key, train_csv=None, preprocess_meta=args.preprocess_meta, model_path=args.model, centers_path=args.centers, roadrisk_url=args.roadrisk_url)
            else:
                # local-only path: build row, prepare features using preprocess_meta, and run cached model
                df_row = owi.build_row(lat, lon, dt_iso=dt, street=None, extra_weather=None)
                x_tensor, feature_columns = owi.prepare_features(df_row, train_csv=None, preprocess_meta=args.preprocess_meta)
                # ensure the model/centers are cached
                if owi._CACHED_MODEL is None:
                    owi.init_inference(model_path=args.model, centers_path=args.centers, preprocess_meta=args.preprocess_meta)
                model = owi._CACHED_MODEL
                centers = owi._CACHED_CENTERS
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                model.to(device)
                xt = x_tensor.to(device)
                with torch.no_grad():
                    logits = model(xt)
                    probs = F.softmax(logits, dim=1).cpu().numpy()[0]
                pred_idx = int(probs.argmax())
                confidence = float(probs.max())
                out = {
                    'pred_cluster': pred_idx,
                    'confidence': confidence,
                    'probabilities': probs.tolist(),
                    'centroid': centers[pred_idx].tolist() if centers is not None else None,
                    'feature_columns': feature_columns,
                }
        except Exception as e:
            out = {'error': str(e)}

        # combine row and output into flat result
        res = {
            'orig_index': i,
            'lat': lat,
            'lon': lon,
            'datetime': str(dt),
        }
        if 'error' in out:
            res.update({'error': out['error']})
        else:
            res.update({
                'pred_cluster': int(out.get('pred_cluster')),
                'confidence': float(out.get('confidence')),
            })
        results.append(res)

        if (len(results) % 50) == 0:
            print(f'Processed {len(results)}/{nrows} rows...')

    elapsed = time.time() - t0
    print(f'Finished {len(results)} rows in {elapsed:.2f}s')
    out_df = pd.DataFrame(results)
    out_df.to_csv(args.out, index=False)
    print('Wrote', args.out)


if __name__ == '__main__':
    main()
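
# Example invocations (illustrative; the script filename here is an assumption):
#   python run_csv_inference.py data.csv --subset 100 --out inference_results.csv
#   python run_csv_inference.py data.csv --live --api-key YOUR_OPENWEATHER_API_KEY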