Fixed POST request handling
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
roadcast/app.py
@@ -1,6 +1,8 @@
from flask import Flask, request, jsonify
from dotenv import load_dotenv
from openmeteo_client import compute_index
from train import compute_index
from models import load_model
from models import MLP

# Load environment variables from .env file
load_dotenv()
@@ -9,6 +11,7 @@ app = Flask(__name__)
import os
import threading
import json
import numpy as np # added

# ML imports are lazy to avoid heavy imports on simple runs

@@ -60,69 +63,30 @@ def health():
|
||||
except Exception as e:
|
||||
return jsonify({'ok': False, 'error': str(e)}), 500
|
||||
|
||||
if __name__ == '__main__':
|
||||
# eager load model/artifacts at startup (best-effort)
|
||||
try:
|
||||
from openmeteo_inference import init_inference
|
||||
init_inference()
|
||||
except Exception:
|
||||
pass
|
||||
app.run(debug=True)
|
||||
|
||||
@app.route('/predict', methods=['POST', 'GET'])
def predict_endpoint():
"""Predict a route between two points given source and destination coordinates (lat/lon).

Expected usage:
- POST with JSON: {"source": {"lat": .., "lon": ..}, "destination": {"lat": .., "lon": ..}}
- GET returns usage instructions for quick browser testing.
GET uses an example payload; POST accepts JSON with 'source' and 'destination'.
Both methods run the same model/index logic and return the same response format.
"""
example_payload = {
"source": {"lat": 38.9, "lon": -77.0},
"destination": {"lat": 38.95, "lon": -77.02}
}
info = "This endpoint expects a POST with JSON body."

info = "This endpoint expects a POST with JSON body (GET returns example payload)."
note = (
"Use POST to receive a prediction. Example: curl -X POST -H 'Content-Type: application/json' "
"-d '{\"source\": {\"lat\": 38.9, \"lon\": -77.0}, \"destination\": {\"lat\": 38.95, \"lon\": -77.02}}' "
"http://127.0.0.1:5000/predict"
)
|
||||
|
||||
# unify request data: GET -> example, POST -> request.json
|
||||
if request.method == 'GET':
|
||||
# Return the same structure as POST but without prediction
|
||||
# response_payload = {
|
||||
# "index": None,
|
||||
# "prediction": {},
|
||||
# "called_with": "GET",
|
||||
# "diagnostics": {},
|
||||
# "example": example_payload,
|
||||
# "info": info,
|
||||
# "note": note
|
||||
# }
|
||||
data = example_payload
|
||||
else:
|
||||
data = request.json or {}
|
||||
|
||||
# For GET request, compute the road risk index using the example coordinates
|
||||
src_lat = example_payload['source']['lat']
|
||||
src_lon = example_payload['source']['lon']
|
||||
dst_lat = example_payload['destination']['lat']
|
||||
dst_lon = example_payload['destination']['lon']
|
||||
|
||||
# Use the compute_index function to get the road risk index
|
||||
index = compute_index(src_lat, src_lon)
|
||||
|
||||
# Prepare the response payload
|
||||
response_payload = {
|
||||
"index": index, # The computed index here
|
||||
"prediction": {},
|
||||
"called_with": "GET",
|
||||
"diagnostics": {},
|
||||
"example": example_payload,
|
||||
"info": info,
|
||||
"note": note
|
||||
}
|
||||
return jsonify(response_payload), 200
|
||||
|
||||
# POST request logic
|
||||
data = request.json or {}
|
||||
source = data.get('source')
|
||||
destination = data.get('destination')
|
||||
if not source or not destination:
|
||||
@@ -136,89 +100,81 @@ def predict_endpoint():
|
||||
except (TypeError, ValueError):
|
||||
return jsonify({"error": "invalid lat or lon values; must be numbers"}), 400
|
||||
|
||||
# Ensure compute_reroute exists and is callable
|
||||
# load model (loader infers architecture from checkpoint)
|
||||
try:
|
||||
from openmeteo_client import compute_reroute
|
||||
model = load_model('model.pth', MLP)
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
"error": "compute_reroute not found in openmeteo_client",
|
||||
"detail": str(e),
|
||||
"hint": "Provide openmeteo_client.compute_reroute "
|
||||
"(Open-Meteo does not need an API key)"
|
||||
}), 500
|
||||
return jsonify({"error": "model load failed", "detail": str(e)}), 500
|
||||
|
||||
if not callable(compute_reroute):
|
||||
return jsonify({"error": "openmeteo_client.compute_reroute is not callable"}), 500
|
||||
|
||||
def _extract_index(res):
|
||||
if res is None:
|
||||
return None
|
||||
if isinstance(res, (int, float)):
|
||||
return int(res)
|
||||
if isinstance(res, dict):
|
||||
for k in ('index', 'idx', 'cluster', 'cluster_idx', 'label_index', 'label_idx'):
|
||||
if k in res:
|
||||
try:
|
||||
return int(res[k])
|
||||
except Exception:
|
||||
return res[k]
|
||||
return None
|
||||
|
||||
# Call compute_reroute (Open-Meteo requires no API key)
|
||||
# infer expected input dim from model first linear weight
|
||||
try:
|
||||
result = compute_reroute(src_lat, src_lon, dst_lat, dst_lon)
|
||||
called_with = "positional"
|
||||
input_dim = None
|
||||
for v in model.state_dict().values():
|
||||
if getattr(v, "dim", None) and v.dim() == 2:
|
||||
input_dim = int(v.shape[1])
|
||||
break
|
||||
if input_dim is None:
|
||||
input_dim = 2
|
||||
except Exception:
|
||||
input_dim = 2
|
||||
|
||||
diagnostics = {"type": type(result).__name__}
|
||||
# build feature vector of correct length and populate lat/lon using preprocess meta if available
|
||||
feature_vector = np.zeros(int(input_dim), dtype=float)
|
||||
meta_path = os.path.join(os.getcwd(), 'preprocess_meta.npz')
|
||||
if os.path.exists(meta_path):
|
||||
try:
|
||||
diagnostics["repr"] = repr(result)[:1000]
|
||||
meta = np.load(meta_path, allow_pickle=True)
|
||||
cols = [str(x) for x in meta['feature_columns'].tolist()]
|
||||
means = meta.get('means')
|
||||
if means is not None and len(means) == input_dim:
|
||||
feature_vector[:] = means
|
||||
col_lower = [c.lower() for c in cols]
|
||||
if 'lat' in col_lower:
|
||||
feature_vector[col_lower.index('lat')] = src_lat
|
||||
elif 'latitude' in col_lower:
|
||||
feature_vector[col_lower.index('latitude')] = src_lat
|
||||
else:
|
||||
feature_vector[0] = src_lat
|
||||
if 'lon' in col_lower:
|
||||
feature_vector[col_lower.index('lon')] = src_lon
|
||||
elif 'longitude' in col_lower:
|
||||
feature_vector[col_lower.index('longitude')] = src_lon
|
||||
else:
|
||||
if input_dim > 1:
|
||||
feature_vector[1] = src_lon
|
||||
except Exception:
|
||||
diagnostics["repr"] = "<unrepr-able>"
|
||||
|
||||
# Normalize return types
|
||||
if isinstance(result, (list, tuple)):
|
||||
idx = None
|
||||
for el in result:
|
||||
idx = _extract_index(el)
|
||||
if idx is not None:
|
||||
break
|
||||
prediction = {"items": list(result)}
|
||||
index = idx
|
||||
elif isinstance(result, dict):
|
||||
index = _extract_index(result)
|
||||
prediction = result
|
||||
elif isinstance(result, (int, float, str)):
|
||||
index = _extract_index(result)
|
||||
prediction = {"value": result}
|
||||
else:
|
||||
index = None
|
||||
prediction = {"value": result}
|
||||
|
||||
response_payload = {
|
||||
"index": index,
|
||||
"prediction": prediction,
|
||||
"called_with": called_with,
|
||||
"diagnostics": diagnostics,
|
||||
"example": example_payload,
|
||||
"info": info,
|
||||
"note": note
|
||||
}
|
||||
|
||||
# Add warning if no routing/index info found
|
||||
expected_keys = ('route', 'path', 'distance', 'directions', 'index', 'idx', 'cluster')
|
||||
if (not isinstance(prediction, dict) or not any(k in prediction for k in expected_keys)) and index is None:
|
||||
response_payload["warning"] = (
|
||||
"No routing/index information returned from compute_reroute. "
|
||||
"See diagnostics for details."
|
||||
)
|
||||
|
||||
return jsonify(response_payload), 200
|
||||
feature_vector[:] = 0.0
|
||||
feature_vector[0] = src_lat
|
||||
if input_dim > 1:
|
||||
feature_vector[1] = src_lon
|
||||
else:
|
||||
feature_vector[0] = src_lat
|
||||
if input_dim > 1:
|
||||
feature_vector[1] = src_lon
|
||||
|
||||
# compute index using model
|
||||
try:
|
||||
index = compute_index(model, feature_vector)
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
"error": "Error processing the request",
|
||||
"detail": str(e)
|
||||
}), 500
|
||||
return jsonify({"error": "compute_index failed", "detail": str(e)}), 500
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({"error": "compute_reroute invocation failed", "detail": str(e)}), 500
|
||||
response_payload = {
|
||||
"index": index,
|
||||
"prediction": {},
|
||||
"called_with": request.method,
|
||||
"diagnostics": {"input_dim": int(input_dim)},
|
||||
"example": example_payload,
|
||||
"info": info,
|
||||
"note": note
|
||||
}
|
||||
|
||||
return jsonify(response_payload), 200
|
||||
|
||||
if __name__ == '__main__':
|
||||
# eager load model/artifacts at startup (best-effort)
|
||||
try:
|
||||
from openmeteo_inference import init_inference
|
||||
init_inference()
|
||||
except Exception:
|
||||
pass
|
||||
app.run(debug=True)
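
As a quick end-to-end check of the POST path this commit fixes, here is a minimal sketch using the requests library (an illustration only: it assumes the Flask app above is running locally via app.run, and it reuses the example payload from the endpoint's docstring):

import requests

payload = {
    "source": {"lat": 38.9, "lon": -77.0},
    "destination": {"lat": 38.95, "lon": -77.02},
}
resp = requests.post("http://127.0.0.1:5000/predict", json=payload, timeout=10)
resp.raise_for_status()
body = resp.json()
print(body["index"], body["called_with"])
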
|
||||
@@ -1,26 +1,4 @@
train the model:
python train.py data.csv --model-type mlp --generate-labels --label-method kmeans --n-buckets 50 --hidden-dims 512,256 --epochs 8 --batch-size 256 --feature-engineer --weight-decay 1e-5 --seed 42

python train.py data.csv --model-type mlp --generate-labels --label-method kmeans --n-buckets 50 --hidden-dims 1024,512 --epochs 12 --batch-size 256 --lr 1e-3 --lr-step-size 4 --lr-gamma 0.5 --feature-engineer --weight-decay 1e-5 --seed 42

# train with outputs saved to output/
python train.py data.csv --model-type mlp --generate-labels --label-method kmeans --n-buckets 50 --hidden-dims 512,256 --epochs 8 --batch-size 256 --feature-engineer --weight-decay 1e-5 --seed 42 --output-dir output/

# evaluate and visualize:
python evaluate_and_visualize.py \
--checkpoint path/to/checkpoint.pt \
--data data.csv \
--label-col original_label_column_name \
--batch-size 256 \
--sample-index 5 \
--plot

# evaluate
python evaluate_and_visualize.py --checkpoint output/model.pth --data data.csv --label-col label --plot --sample-index 5

# If you used generated labels during training and train.py saved metadata,
|
||||
# the evaluator will prefer generated labels saved in label_info.json inside the checkpoint dir.
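
For reference, the label-generation code in train.py (shown later in this diff) writes label_info.json with roughly the following shape; the assignment values below are made up purely for illustration:

# label_info.json (illustrative)
# {
#   "generated": true,
#   "label_method": "kmeans",
#   "n_clusters": 10,
#   "assignments": [3.0, 7.0, 1.0, ...]   # or "assignments_file": "label_assignments.npz" for large datasets
# }
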
|
||||
|
||||
# fetch weather (placeholder)
|
||||
# from openweather_client import fetch_road_risk
|
||||
# print(fetch_road_risk(37.7749, -122.4194))
|
||||
python train.py data.csv --model-type mlp --generate-labels --label-method kmeans --n-buckets 50 --hidden-dims 1024,512 --epochs 12 --batch-size 256 --lr 1e-3 --lr-step-size 4 --lr-gamma 0.5 --feature-engineer --weight-decay 1e-5 --seed 42
|
||||
Binary file not shown.
@@ -1,13 +1,147 @@
|
||||
# import torch
|
||||
# import torch.nn as nn
|
||||
# import math
|
||||
# from typing import Union, Iterable
|
||||
# import numpy as np
|
||||
# import torch as _torch
|
||||
|
||||
# def accidents_to_bucket(count: Union[int, float, Iterable],
|
||||
# max_count: int = 20000,
|
||||
# num_bins: int = 10) -> Union[int, list, _torch.Tensor, np.ndarray]:
|
||||
# """
|
||||
# Map accident counts to simple buckets 1..num_bins (equal-width).
|
||||
# Example: max_count=20000, num_bins=10 -> bin width = 2000
|
||||
# 0-1999 -> 1, 2000-3999 -> 2, ..., 18000-20000 -> 10
|
||||
|
||||
# Args:
|
||||
# count: single value or iterable (list/numpy/torch). Values <=0 map to 1, values >= max_count map to num_bins.
|
||||
# max_count: expected maximum count (top of highest bin).
|
||||
# num_bins: number of buckets (default 10).
|
||||
|
||||
# Returns:
|
||||
# Same type as input (int for scalar, list/numpy/torch for iterables) with values in 1..num_bins.
|
||||
# """
|
||||
# width = max_count / float(num_bins)
|
||||
# def _bucket_scalar(x):
|
||||
# # clamp
|
||||
# x = 0.0 if x is None else float(x)
|
||||
# if x <= 0:
|
||||
# return 1
|
||||
# if x >= max_count:
|
||||
# return num_bins
|
||||
# return int(x // width) + 1
|
||||
|
||||
# # scalar int/float
|
||||
# if isinstance(count, (int, float)):
|
||||
# return _bucket_scalar(count)
|
||||
|
||||
# # torch tensor
|
||||
# if isinstance(count, _torch.Tensor):
|
||||
# x = count.clone().float()
|
||||
# x = _torch.clamp(x, min=0.0, max=float(max_count))
|
||||
# buckets = (x // width).to(_torch.long) + 1
|
||||
# buckets = _torch.clamp(buckets, min=1, max=num_bins)
|
||||
# return buckets
|
||||
|
||||
# # numpy array
|
||||
# if isinstance(count, np.ndarray):
|
||||
# x = np.clip(count.astype(float), 0.0, float(max_count))
|
||||
# buckets = (x // width).astype(int) + 1
|
||||
# return np.clip(buckets, 1, num_bins)
|
||||
|
||||
# # generic iterable -> list
|
||||
# if isinstance(count, Iterable):
|
||||
# return [ _bucket_scalar(float(x)) for x in count ]
|
||||
|
||||
# # fallback
|
||||
# return _bucket_scalar(float(count))
|
||||
|
||||
|
||||
# class SimpleCNN(nn.Module):
|
||||
# """A small CNN for image classification (adjustable). Automatically computes flattened size."""
|
||||
# def __init__(self, in_channels=3, num_classes=10, input_size=(3, 224, 224)):
|
||||
# super().__init__()
|
||||
# self.features = nn.Sequential(
|
||||
# nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
|
||||
# nn.ReLU(),
|
||||
# nn.MaxPool2d(2),
|
||||
# nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||
# nn.ReLU(),
|
||||
# nn.MaxPool2d(2),
|
||||
# )
|
||||
# # compute flatten size using a dummy tensor
|
||||
# with torch.no_grad():
|
||||
# dummy = torch.zeros(1, *input_size)
|
||||
# feat = self.features(dummy)
|
||||
# # flat_features was previously computed as:
|
||||
# # int(feat.numel() / feat.shape[0])
|
||||
# # Explanation:
|
||||
# # feat.shape == (N, C, H, W) (for image inputs)
|
||||
# # feat.numel() == N * C * H * W
|
||||
# # dividing by N (feat.shape[0]) yields C * H * W, i.e. flattened size per sample
|
||||
# # Clearer alternative using tensor shape:
|
||||
# flat_features = int(torch.prod(torch.tensor(feat.shape[1:])).item())
|
||||
# # If you need the linear index mapping for coordinates (c, h, w):
|
||||
# # idx = c * (H * W) + h * W + w
|
||||
|
||||
# self.classifier = nn.Sequential(
|
||||
# nn.Flatten(),
|
||||
# nn.Linear(flat_features, 256),
|
||||
# nn.ReLU(),
|
||||
# nn.Dropout(0.5),
|
||||
# nn.Linear(256, num_classes),
|
||||
# )
|
||||
|
||||
# def forward(self, x):
|
||||
# x = self.features(x)
|
||||
# x = self.classifier(x)
|
||||
# return x
|
||||
|
||||
|
||||
# class MLP(nn.Module):
|
||||
# """Simple MLP for tabular CSV data classification."""
|
||||
# def __init__(self, input_dim, hidden_dims=(256, 128), num_classes=2):
|
||||
# super().__init__()
|
||||
# layers = []
|
||||
# prev = input_dim
|
||||
# for h in hidden_dims:
|
||||
# layers.append(nn.Linear(prev, h))
|
||||
# layers.append(nn.ReLU())
|
||||
# layers.append(nn.Dropout(0.2))
|
||||
# prev = h
|
||||
# layers.append(nn.Linear(prev, num_classes))
|
||||
# self.net = nn.Sequential(*layers)
|
||||
|
||||
# def forward(self, x):
|
||||
# return self.net(x)
|
||||
|
||||
|
||||
# def create_model(device=None, in_channels=3, num_classes=10, input_size=(3, 224, 224), model_type='cnn', input_dim=None, hidden_dims=None):
|
||||
# if model_type == 'mlp':
|
||||
# if input_dim is None:
|
||||
# raise ValueError('input_dim is required for mlp model_type')
|
||||
# if hidden_dims is None:
|
||||
# model = MLP(input_dim=input_dim, num_classes=num_classes)
|
||||
# else:
|
||||
# model = MLP(input_dim=input_dim, hidden_dims=hidden_dims, num_classes=num_classes)
|
||||
# else:
|
||||
# model = SimpleCNN(in_channels=in_channels, num_classes=num_classes, input_size=input_size)
|
||||
|
||||
# if device:
|
||||
# model.to(device)
|
||||
# return model
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
from typing import Union, Iterable
|
||||
import numpy as np
|
||||
import torch as _torch
|
||||
import os
|
||||
|
||||
# Retaining the existing `accidents_to_bucket` function for accident categorization
|
||||
def accidents_to_bucket(count: Union[int, float, Iterable],
|
||||
max_count: int = 20000,
|
||||
num_bins: int = 10) -> Union[int, list, _torch.Tensor, np.ndarray]:
|
||||
num_bins: int = 10) -> Union[int, list, torch.Tensor, np.ndarray]:
|
||||
"""
|
||||
Map accident counts to simple buckets 1..num_bins (equal-width).
|
||||
Example: max_count=20000, num_bins=10 -> bin width = 2000
|
||||
@@ -36,11 +170,11 @@ def accidents_to_bucket(count: Union[int, float, Iterable],
|
||||
return _bucket_scalar(count)
|
||||
|
||||
# torch tensor
|
||||
if isinstance(count, _torch.Tensor):
|
||||
if isinstance(count, torch.Tensor):
|
||||
x = count.clone().float()
|
||||
x = _torch.clamp(x, min=0.0, max=float(max_count))
|
||||
buckets = (x // width).to(_torch.long) + 1
|
||||
buckets = _torch.clamp(buckets, min=1, max=num_bins)
|
||||
x = torch.clamp(x, min=0.0, max=float(max_count))
|
||||
buckets = (x // width).to(torch.long) + 1
|
||||
buckets = torch.clamp(buckets, min=1, max=num_bins)
|
||||
return buckets
|
||||
|
||||
# numpy array
|
||||
@@ -57,8 +191,8 @@ def accidents_to_bucket(count: Union[int, float, Iterable],
|
||||
return _bucket_scalar(float(count))
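
# A minimal usage sketch of the bucketing above (illustrative values; with the defaults
# max_count=20000 and num_bins=10 each bucket is 2000 wide):
#   accidents_to_bucket(0)     -> 1   (counts <= 0 fall in the first bucket)
#   accidents_to_bucket(4500)  -> 3   (4500 // 2000 = 2, plus 1)
#   accidents_to_bucket(np.array([0, 2500, 25000]))  -> array([ 1,  2, 10])  (values >= max_count clamp to 10)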
|
||||
|
||||
|
||||
# SimpleCNN: CNN model for image classification
|
||||
class SimpleCNN(nn.Module):
|
||||
"""A small CNN for image classification (adjustable). Automatically computes flattened size."""
|
||||
def __init__(self, in_channels=3, num_classes=10, input_size=(3, 224, 224)):
|
||||
super().__init__()
|
||||
self.features = nn.Sequential(
|
||||
@@ -69,21 +203,11 @@ class SimpleCNN(nn.Module):
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2),
|
||||
)
|
||||
# compute flatten size using a dummy tensor
|
||||
with torch.no_grad():
|
||||
dummy = torch.zeros(1, *input_size)
|
||||
feat = self.features(dummy)
|
||||
# flat_features was previously computed as:
|
||||
# int(feat.numel() / feat.shape[0])
|
||||
# Explanation:
|
||||
# feat.shape == (N, C, H, W) (for image inputs)
|
||||
# feat.numel() == N * C * H * W
|
||||
# dividing by N (feat.shape[0]) yields C * H * W, i.e. flattened size per sample
|
||||
# Clearer alternative using tensor shape:
|
||||
flat_features = int(torch.prod(torch.tensor(feat.shape[1:])).item())
|
||||
# If you need the linear index mapping for coordinates (c, h, w):
|
||||
# idx = c * (H * W) + h * W + w
|
||||
|
||||
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(flat_features, 256),
|
||||
@@ -100,7 +224,7 @@ class SimpleCNN(nn.Module):
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""Simple MLP for tabular CSV data classification."""
|
||||
def __init__(self, input_dim, hidden_dims=(256, 128), num_classes=2):
|
||||
def __init__(self, input_dim=58, hidden_dims=(1024, 512, 50), num_classes=10):
|
||||
super().__init__()
|
||||
layers = []
|
||||
prev = input_dim
|
||||
@@ -116,17 +240,148 @@ class MLP(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def load_model(model_path, model_class, input_dim=None):
|
||||
"""
|
||||
Load the model weights from the given path and initialize the model class.
|
||||
|
||||
Behavior:
|
||||
- If the checkpoint contains 'model_config', use it to build the model.
|
||||
- Otherwise infer input_dim / hidden_dims / num_classes from the state_dict shapes.
|
||||
- model_class must be MLP or SimpleCNN; for MLP input_dim may be inferred if not provided.
|
||||
"""
|
||||
import torch
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"model file not found: {model_path}")
|
||||
|
||||
ckpt = torch.load(model_path, map_location=torch.device('cpu'))
|
||||
|
||||
# locate state dict
|
||||
state = None
|
||||
for k in ('model_state_dict', 'state_dict', 'model'):
|
||||
if k in ckpt and isinstance(ckpt[k], dict):
|
||||
state = ckpt[k]
|
||||
break
|
||||
if state is None:
|
||||
# maybe the file directly contains the state_dict
|
||||
if isinstance(ckpt, dict) and any(k.endswith('.weight') for k in ckpt.keys()):
|
||||
state = ckpt
|
||||
else:
|
||||
raise ValueError("No state_dict found in checkpoint")
|
||||
|
||||
# prefer explicit model_config if present
|
||||
model_config = ckpt.get('model_config') or ckpt.get('config') or {}
|
||||
|
||||
# helper to infer MLP params from state_dict if no config provided
|
||||
def _infer_mlp_from_state(state_dict):
|
||||
# collect net.*.weight keys (MLP uses 'net' module)
|
||||
weight_items = []
|
||||
for k in state_dict.keys():
|
||||
if k.endswith('.weight') and k.startswith('net.'):
|
||||
try:
|
||||
idx = int(k.split('.')[1])
|
||||
except Exception:
|
||||
continue
|
||||
weight_items.append((idx, k))
|
||||
if not weight_items:
|
||||
# fallback: take all weight-like keys in order
|
||||
weight_items = [(i, k) for i, k in enumerate(sorted([k for k in state_dict.keys() if k.endswith('.weight')]))]
|
||||
weight_items.sort()
|
||||
shapes = [tuple(state_dict[k].shape) for _, k in weight_items]
|
||||
# shapes are (out, in) for each Linear
|
||||
if not shapes:
|
||||
raise ValueError("Cannot infer MLP structure from state_dict")
|
||||
input_dim_inferred = int(shapes[0][1])
|
||||
hidden_dims_inferred = [int(s[0]) for s in shapes[:-1]] # all but last are hidden layer outputs
|
||||
num_classes_inferred = int(shapes[-1][0])
|
||||
return input_dim_inferred, tuple(hidden_dims_inferred), num_classes_inferred
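
# Worked example (hypothetical checkpoint matching the MLP defaults above):
#   net.0.weight (1024, 58), net.3.weight (512, 1024), net.6.weight (50, 512), net.9.weight (10, 50)
#   -> input_dim=58, hidden_dims=(1024, 512, 50), num_classes=10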
|
||||
|
||||
# instantiate model
|
||||
if model_class == MLP:
|
||||
# prefer values from model_config
|
||||
cfg_input_dim = model_config.get('input_dim')
|
||||
cfg_hidden = model_config.get('hidden_dims') or model_config.get('hidden_dim') or model_config.get('hidden')
|
||||
cfg_num_classes = model_config.get('num_classes')
|
||||
|
||||
use_input_dim = input_dim or cfg_input_dim
|
||||
use_hidden = cfg_hidden
|
||||
use_num_classes = cfg_num_classes
|
||||
|
||||
if use_input_dim is None or use_num_classes is None:
|
||||
# infer from state
|
||||
inferred_input, inferred_hidden, inferred_num = _infer_mlp_from_state(state)
|
||||
if use_input_dim is None:
|
||||
use_input_dim = inferred_input
|
||||
if use_hidden is None:
|
||||
use_hidden = inferred_hidden
|
||||
if use_num_classes is None:
|
||||
use_num_classes = inferred_num
|
||||
|
||||
# normalize hidden dims to tuple if needed
|
||||
if use_hidden is None:
|
||||
use_hidden = (256, 128)
|
||||
elif isinstance(use_hidden, (list, tuple)):
|
||||
use_hidden = tuple(use_hidden)
|
||||
else:
|
||||
# sometimes stored as string
|
||||
try:
|
||||
use_hidden = tuple(int(x) for x in str(use_hidden).strip('()[]').split(',') if x)
|
||||
except Exception:
|
||||
use_hidden = (256, 128)
|
||||
|
||||
model = MLP(input_dim=int(use_input_dim), hidden_dims=use_hidden, num_classes=int(use_num_classes))
|
||||
|
||||
elif model_class == SimpleCNN:
|
||||
# use model_config if present
|
||||
cfg_num_classes = model_config.get('num_classes') or 10
|
||||
cfg_input_size = model_config.get('input_size') or (3, 224, 224)
|
||||
model = SimpleCNN(in_channels=cfg_input_size[0], num_classes=int(cfg_num_classes), input_size=tuple(cfg_input_size))
|
||||
else:
|
||||
raise ValueError(f"Unsupported model class: {model_class}")
|
||||
|
||||
# load weights into model
|
||||
try:
|
||||
model.load_state_dict(state)
|
||||
except Exception as e:
|
||||
# provide helpful diagnostics
|
||||
model_keys = list(model.state_dict().keys())[:50]
|
||||
state_keys = list(state.keys())[:50]
|
||||
raise RuntimeError(f"Failed to load state_dict: {e}. model_keys_sample={model_keys}, state_keys_sample={state_keys}")
|
||||
|
||||
return model
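
# Usage sketch (paths are hypothetical; the architecture is normally read from the
# checkpoint itself, so input_dim only needs to be passed as an explicit override):
#   model = load_model('model.pth', MLP)
#   model = load_model('model.pth', MLP, input_dim=58)
#   model.eval()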
|
||||
|
||||
# Helper function to create different types of models
|
||||
def create_model(device=None, in_channels=3, num_classes=10, input_size=(3, 224, 224), model_type='cnn', input_dim=None, hidden_dims=None):
|
||||
"""
|
||||
Creates and returns a model based on the provided configuration.
|
||||
|
||||
Args:
|
||||
device (str or torch.device, optional): The device to run the model on ('cpu' or 'cuda').
|
||||
in_channels (int, optional): The number of input channels (default 3 for RGB images).
|
||||
num_classes (int, optional): The number of output classes (default 10).
|
||||
input_size (tuple, optional): The input size for the model (default (3, 224, 224)).
|
||||
model_type (str, optional): The type of model ('cnn' for convolutional, 'mlp' for multi-layer perceptron).
|
||||
input_dim (int, optional): The input dimension for the MLP (used only if `model_type == 'mlp'`).
|
||||
hidden_dims (tuple, optional): The dimensions of hidden layers for the MLP (used only if `model_type == 'mlp'`).
|
||||
|
||||
Returns:
|
||||
model (nn.Module): The created model.
|
||||
"""
|
||||
if model_type == 'mlp':
|
||||
if input_dim is None:
|
||||
raise ValueError('input_dim is required for mlp model_type')
|
||||
if hidden_dims is None:
|
||||
model = MLP(input_dim=input_dim, num_classes=num_classes)
|
||||
else:
|
||||
model = MLP(input_dim=input_dim, hidden_dims=hidden_dims, num_classes=num_classes)
|
||||
model = MLP(input_dim=input_dim, hidden_dims=hidden_dims or (256, 128), num_classes=num_classes)
|
||||
else:
|
||||
model = SimpleCNN(in_channels=in_channels, num_classes=num_classes, input_size=input_size)
|
||||
|
||||
if device:
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
|
||||
# Example for using load_model and create_model in the codebase:
|
||||
|
||||
# Loading a model
|
||||
# model = load_model('path_to_model.pth', SimpleCNN)  # move to GPU afterwards with model.to('cuda') if needed
|
||||
|
||||
# Creating a model for inference
|
||||
# model = create_model(device='cuda', model_type='cnn', num_classes=5)
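
# Creating an MLP for tabular features (illustrative; input_dim is required for model_type='mlp')
# model = create_model(model_type='mlp', input_dim=58, hidden_dims=(1024, 512, 50), num_classes=10)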
|
||||
@@ -294,41 +294,41 @@ def compute_reroute(
|
||||
"grid_shape": (n_lat, n_lon)
|
||||
}
|
||||
|
||||
def compute_index(lat: float,
|
||||
lon: float,
|
||||
max_risk: float = 10.0,
|
||||
num_bins: int = 10,
|
||||
risk_provider: Optional[Callable[[float, float], float]] = None) -> int:
|
||||
"""
|
||||
Computes and returns an index based on road risk and accident information.
|
||||
# def compute_index(lat: float,
|
||||
# lon: float,
|
||||
# max_risk: float = 10.0,
|
||||
# num_bins: int = 10,
|
||||
# risk_provider: Optional[Callable[[float, float], float]] = None) -> int:
|
||||
# """
|
||||
# Computes and returns an index based on road risk and accident information.
|
||||
|
||||
Args:
|
||||
lat: Latitude of the location.
|
||||
lon: Longitude of the location.
|
||||
max_risk: Maximum possible road risk score.
|
||||
num_bins: Number of bins to divide the risk range into.
|
||||
reroute_kwargs: Optional dictionary passed to reroute logic.
|
||||
risk_provider: Optional custom risk provider function.
|
||||
# Args:
|
||||
# lat: Latitude of the location.
|
||||
# lon: Longitude of the location.
|
||||
# max_risk: Maximum possible road risk score.
|
||||
# num_bins: Number of bins to divide the risk range into.
|
||||
# reroute_kwargs: Optional dictionary passed to reroute logic.
|
||||
# risk_provider: Optional custom risk provider function.
|
||||
|
||||
Returns:
|
||||
An integer index (1 to num_bins) based on computed road risk.
|
||||
"""
|
||||
# If a risk provider is not provided, use a default one
|
||||
if risk_provider is None:
|
||||
risk_provider = lambda lat, lon: get_risk_score(lat, lon)
|
||||
# Returns:
|
||||
# An integer index (1 to num_bins) based on computed road risk.
|
||||
# """
|
||||
# # If a risk provider is not provided, use a default one
|
||||
# if risk_provider is None:
|
||||
# risk_provider = lambda lat, lon: get_risk_score(lat, lon)
|
||||
|
||||
# Fetch road risk score using the provided risk provider
|
||||
road_risk = float(risk_provider(lat, lon))
|
||||
# # Fetch road risk score using the provided risk provider
|
||||
# road_risk = float(risk_provider(lat, lon))
|
||||
|
||||
# Compute the index based on road risk score and the max_risk, num_bins parameters
|
||||
# The formula will divide the risk score into `num_bins` bins
|
||||
# The index will be a number between 1 and num_bins based on the risk score
|
||||
# # Compute the index based on road risk score and the max_risk, num_bins parameters
|
||||
# # The formula will divide the risk score into `num_bins` bins
|
||||
# # The index will be a number between 1 and num_bins based on the risk score
|
||||
|
||||
# Normalize the risk score to be between 0 and max_risk
|
||||
normalized_risk = min(road_risk, max_risk) / max_risk
|
||||
# # Normalize the risk score to be between 0 and max_risk
|
||||
# normalized_risk = min(road_risk, max_risk) / max_risk
|
||||
|
||||
# Compute the index based on the normalized risk score
|
||||
index = int(normalized_risk * num_bins)
|
||||
# # Compute the index based on the normalized risk score
|
||||
# index = int(normalized_risk * num_bins)
|
||||
|
||||
# Ensure the index is within the expected range of 1 to num_bins
|
||||
return max(1, min(index + 1, num_bins)) # Adding 1 because index is 0-based
|
||||
# # Ensure the index is within the expected range of 1 to num_bins
|
||||
# return max(1, min(index + 1, num_bins)) # Adding 1 because index is 0-based
|
||||
@@ -1,342 +0,0 @@
|
||||
"""OpenWeather / Road Risk client.
|
||||
|
||||
Provides:
|
||||
- fetch_weather(lat, lon, api_key=None)
|
||||
- fetch_road_risk(lat, lon, api_key=None, roadrisk_url=None, extra_params=None)
|
||||
|
||||
Never hardcode API keys in source. Provide via api_key argument or set OPENWEATHER_API_KEY / OPENWEATHER_KEY env var.
|
||||
"""
|
||||
import os
|
||||
from typing import Tuple, Dict, Any, Optional, Callable, List
|
||||
import requests
|
||||
import heapq
|
||||
import math
|
||||
|
||||
def _get_api_key(explicit_key: Optional[str] = None) -> Optional[str]:
|
||||
if explicit_key:
|
||||
return explicit_key
|
||||
return os.environ.get("OPENWEATHER_API_KEY") or os.environ.get("OPENWEATHER_KEY")
|
||||
|
||||
BASE_URL = "https://api.openweathermap.org/data/2.5"
|
||||
|
||||
|
||||
def fetch_weather(lat: float, lon: float, params: Optional[dict] = None, api_key: Optional[str] = None) -> dict:
|
||||
"""Call standard OpenWeather /weather endpoint and return parsed JSON."""
|
||||
key = _get_api_key(api_key)
|
||||
if key is None:
|
||||
raise RuntimeError("Set OPENWEATHER_API_KEY or OPENWEATHER_KEY or pass api_key")
|
||||
q = {"lat": lat, "lon": lon, "appid": key, "units": "metric"}
|
||||
if params:
|
||||
q.update(params)
|
||||
resp = requests.get(f"{BASE_URL}/weather", params=q, timeout=10)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def fetch_road_risk(lat: float, lon: float, extra_params: Optional[dict] = None, api_key: Optional[str] = None, roadrisk_url: Optional[str] = None) -> Tuple[dict, Dict[str, Any]]:
|
||||
"""
|
||||
Call OpenWeather /roadrisk endpoint (or provided roadrisk_url) and return (raw_json, features).
|
||||
|
||||
features will always include 'road_risk_score' (float). Other numeric fields are included when present.
|
||||
The implementation:
|
||||
- prefers explicit numeric keys (road_risk_score, risk_score, score, risk)
|
||||
- if absent, collects top-level numeric fields and averages common contributors
|
||||
- if still absent, falls back to a simple weather-derived heuristic using /weather
|
||||
|
||||
Note: Do not commit API keys. Pass api_key or set env var.
|
||||
"""
|
||||
key = _get_api_key(api_key)
|
||||
if key is None:
|
||||
raise RuntimeError("Set OPENWEATHER_API_KEY or OPENWEATHER_KEY or pass api_key")
|
||||
|
||||
params = {"lat": lat, "lon": lon, "appid": key}
|
||||
if extra_params:
|
||||
params.update(extra_params)
|
||||
|
||||
url = roadrisk_url or f"{BASE_URL}/roadrisk"
|
||||
resp = requests.get(url, params=params, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
features: Dict[str, Any] = {}
|
||||
risk: Optional[float] = None
|
||||
|
||||
# direct candidates
|
||||
for candidate in ("road_risk_score", "risk_score", "risk", "score"):
|
||||
if isinstance(data, dict) and candidate in data:
|
||||
try:
|
||||
risk = float(data[candidate])
|
||||
features[candidate] = risk
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# if no direct candidate, collect numeric top-level fields
|
||||
if risk is None and isinstance(data, dict):
|
||||
numeric_fields = {}
|
||||
for k, v in data.items():
|
||||
if isinstance(v, (int, float)):
|
||||
numeric_fields[k] = float(v)
|
||||
features.update(numeric_fields)
|
||||
# try averaging common contributors if present
|
||||
contributors = []
|
||||
for name in ("precipitation", "rain", "snow", "visibility", "wind_speed"):
|
||||
if name in data and isinstance(data[name], (int, float)):
|
||||
contributors.append(float(data[name]))
|
||||
if contributors:
|
||||
# average contributors -> risk proxy
|
||||
risk = float(sum(contributors) / len(contributors))
|
||||
|
||||
# fallback: derive crude risk from /weather
|
||||
if risk is None:
|
||||
try:
|
||||
w = fetch_weather(lat, lon, api_key=key)
|
||||
main = w.get("main", {})
|
||||
wind = w.get("wind", {})
|
||||
weather = w.get("weather", [{}])[0]
|
||||
# heuristic: rain + high wind + low visibility
|
||||
derived = 0.0
|
||||
if isinstance(weather.get("main", ""), str) and "rain" in weather.get("main", "").lower():
|
||||
derived += 1.0
|
||||
if (wind.get("speed") or 0) > 6.0:
|
||||
derived += 0.5
|
||||
if (w.get("visibility") or 10000) < 5000:
|
||||
derived += 1.0
|
||||
risk = float(derived)
|
||||
features.update({
|
||||
"temp": main.get("temp"),
|
||||
"humidity": main.get("humidity"),
|
||||
"wind_speed": wind.get("speed"),
|
||||
"visibility": w.get("visibility"),
|
||||
"weather_main": weather.get("main"),
|
||||
"weather_id": weather.get("id"),
|
||||
})
|
||||
except Exception:
|
||||
# cannot derive anything; set neutral 0.0
|
||||
risk = 0.0
|
||||
|
||||
features["road_risk_score"] = float(risk)
|
||||
return data, features
|
||||
|
||||
|
||||
def _haversine_km(a_lat: float, a_lon: float, b_lat: float, b_lon: float) -> float:
|
||||
# returns distance in kilometers
|
||||
R = 6371.0
|
||||
lat1, lon1, lat2, lon2 = map(math.radians, (a_lat, a_lon, b_lat, b_lon))
|
||||
dlat = lat2 - lat1
|
||||
dlon = lon2 - lon1
|
||||
h = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
|
||||
return 2 * R * math.asin(min(1.0, math.sqrt(h)))
|
||||
|
||||
|
||||
def risk_to_index(risk_score: float, max_risk: float = 10.0, num_bins: int = 10) -> int:
|
||||
"""
|
||||
Map a numeric risk_score to an integer index 1..num_bins (higher => more risky).
|
||||
Uses equal-width bins: 0..(max_risk/num_bins) -> 1, ..., >=max_risk -> num_bins.
|
||||
"""
|
||||
if risk_score is None:
|
||||
return 1
|
||||
r = float(risk_score)
|
||||
if r <= 0:
|
||||
return 1
|
||||
if r >= max_risk:
|
||||
return num_bins
|
||||
bin_width = max_risk / float(num_bins)
|
||||
return int(r // bin_width) + 1
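
# Worked example with the defaults (max_risk=10.0, num_bins=10, so bin_width = 1.0):
#   risk_to_index(0.0)  -> 1
#   risk_to_index(3.7)  -> 4    (3.7 // 1.0 = 3, plus 1)
#   risk_to_index(12.0) -> 10   (scores at or above max_risk map to the top bin)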
|
||||
|
||||
|
||||
def get_risk_score(lat: float, lon: float, **fetch_kwargs) -> float:
|
||||
"""
|
||||
Wrapper: calls fetch_road_risk and returns features['road_risk_score'] (float).
|
||||
Pass api_key/roadrisk_url via fetch_kwargs as needed.
|
||||
"""
|
||||
_, features = fetch_road_risk(lat, lon, **fetch_kwargs)
|
||||
return float(features.get("road_risk_score", 0.0))
|
||||
|
||||
|
||||
def compute_reroute(start_lat: float,
|
||||
start_lon: float,
|
||||
risk_provider: Callable[[float, float], float] = None,
|
||||
lat_range: float = 0.005,
|
||||
lon_range: float = 0.01,
|
||||
n_lat: int = 7,
|
||||
n_lon: int = 7,
|
||||
max_calls: Optional[int] = None,
|
||||
distance_weight: float = 0.1) -> Dict[str, Any]:
|
||||
"""
|
||||
Sample a grid around (start_lat, start_lon), get risk at each grid node via risk_provider,
|
||||
find the node with minimum risk, and run Dijkstra on the grid (4-neighbors) where edge cost =
|
||||
average node risk + distance_weight * distance_km. Returns path and stats.
|
||||
|
||||
Defaults: n_lat/n_lon small to limit API calls. max_calls optionally caps number of risk_provider calls.
|
||||
"""
|
||||
if risk_provider is None:
|
||||
# default risk provider that calls fetch_road_risk (may require API key in env or fetch_kwargs)
|
||||
def _rp(lat, lon): return get_risk_score(lat, lon)
|
||||
risk_provider = _rp
|
||||
|
||||
# build grid coordinates
|
||||
lat_steps = n_lat
|
||||
lon_steps = n_lon
|
||||
if lat_steps < 2 or lon_steps < 2:
|
||||
raise ValueError("n_lat and n_lon must be >= 2")
|
||||
lat0 = start_lat - lat_range
|
||||
lon0 = start_lon - lon_range
|
||||
lat_step = (2 * lat_range) / (lat_steps - 1)
|
||||
lon_step = (2 * lon_range) / (lon_steps - 1)
|
||||
|
||||
coords: List[Tuple[float, float]] = []
|
||||
for i in range(lat_steps):
|
||||
for j in range(lon_steps):
|
||||
coords.append((lat0 + i * lat_step, lon0 + j * lon_step))
|
||||
|
||||
# sample risks with caching and optional call limit
|
||||
risks: List[float] = []
|
||||
calls = 0
|
||||
for (lat, lon) in coords:
|
||||
if max_calls is not None and calls >= max_calls:
|
||||
# call budget exhausted: mark this node as unusable (infinite risk) so it is never selected
|
||||
risks.append(float('inf'))
|
||||
continue
|
||||
try:
|
||||
r = float(risk_provider(lat, lon))
|
||||
except Exception:
|
||||
r = float('inf')
|
||||
risks.append(r)
|
||||
calls += 1
|
||||
|
||||
# convert to grid indexed by (i,j)
|
||||
def idx(i, j): return i * lon_steps + j
|
||||
# find start index (closest grid node to start)
|
||||
start_i = round((start_lat - lat0) / lat_step)
|
||||
start_j = round((start_lon - lon0) / lon_step)
|
||||
start_i = max(0, min(lat_steps - 1, start_i))
|
||||
start_j = max(0, min(lon_steps - 1, start_j))
|
||||
start_index = idx(start_i, start_j)
|
||||
|
||||
# find target node = min risk node (ignore inf)
|
||||
min_risk = min(risks)
|
||||
if math.isinf(min_risk) or min_risk >= risks[start_index]:
|
||||
# no better location found or sampling failed
|
||||
return {
|
||||
"reroute_needed": False,
|
||||
"reason": "no_lower_risk_found",
|
||||
"start_coord": (start_lat, start_lon),
|
||||
"start_risk": None if math.isinf(risks[start_index]) else risks[start_index],
|
||||
}
|
||||
|
||||
target_index = int(risks.index(min_risk))
|
||||
|
||||
# Dijkstra from start_index to target_index
|
||||
N = len(coords)
|
||||
dist = [math.inf] * N
|
||||
prev = [None] * N
|
||||
dist[start_index] = 0.0
|
||||
pq = [(0.0, start_index)]
|
||||
while pq:
|
||||
d, u = heapq.heappop(pq)
|
||||
if d > dist[u]:
|
||||
continue
|
||||
if u == target_index:
|
||||
break
|
||||
ui = u // lon_steps
|
||||
uj = u % lon_steps
|
||||
for di, dj in ((1,0),(-1,0),(0,1),(0,-1)):
|
||||
vi, vj = ui + di, uj + dj
|
||||
if 0 <= vi < lat_steps and 0 <= vj < lon_steps:
|
||||
v = idx(vi, vj)
|
||||
# cost: average node risk + small distance penalty
|
||||
ru = risks[u]
|
||||
rv = risks[v]
|
||||
if math.isinf(ru) or math.isinf(rv):
|
||||
continue
|
||||
lat_u, lon_u = coords[u]
|
||||
lat_v, lon_v = coords[v]
|
||||
d_km = _haversine_km(lat_u, lon_u, lat_v, lon_v)
|
||||
w = (ru + rv) / 2.0 + distance_weight * d_km
|
||||
nd = d + w
|
||||
if nd < dist[v]:
|
||||
dist[v] = nd
|
||||
prev[v] = u
|
||||
heapq.heappush(pq, (nd, v))
|
||||
|
||||
if math.isinf(dist[target_index]):
|
||||
return {
|
||||
"reroute_needed": False,
|
||||
"reason": "no_path_found",
|
||||
"start_coord": (start_lat, start_lon),
|
||||
"start_risk": risks[start_index],
|
||||
"target_risk": risks[target_index],
|
||||
}
|
||||
|
||||
# reconstruct path
|
||||
path_indices = []
|
||||
cur = target_index
|
||||
while cur is not None:
|
||||
path_indices.append(cur)
|
||||
cur = prev[cur]
|
||||
path_indices.reverse()
|
||||
path_coords = [coords[k] for k in path_indices]
|
||||
start_risk = risks[start_index]
|
||||
end_risk = risks[target_index]
|
||||
improvement = (start_risk - end_risk) if start_risk not in (None, float('inf')) else None
|
||||
|
||||
return {
|
||||
"reroute_needed": True,
|
||||
"start_coord": (start_lat, start_lon),
|
||||
"start_risk": start_risk,
|
||||
"target_coord": coords[target_index],
|
||||
"target_risk": end_risk,
|
||||
"path": path_coords,
|
||||
"path_cost": dist[target_index],
|
||||
"risk_improvement": improvement,
|
||||
"grid_shape": (lat_steps, lon_steps),
|
||||
"calls_made": calls,
|
||||
}
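
# Usage sketch (illustrative only): a synthetic risk_provider keeps the call deterministic
# and avoids needing any API key; the coordinates are the example DC-area point used elsewhere.
#   demo = compute_reroute(38.9, -77.0,
#                          risk_provider=lambda lat, lon: abs(lat - 38.905) * 100.0,
#                          n_lat=3, n_lon=3)
#   demo["reroute_needed"]               -> True (the northernmost grid row has the lowest synthetic risk)
#   demo["target_coord"], demo["path"]   -> the chosen low-risk node and the grid path to it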
|
||||
|
||||
|
||||
def compute_index_and_reroute(lat: float,
|
||||
lon: float,
|
||||
api_key: Optional[str] = None,
|
||||
roadrisk_url: Optional[str] = None,
|
||||
max_risk: float = 10.0,
|
||||
num_bins: int = 10,
|
||||
reroute_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
High-level convenience: get road risk, map to index (1..num_bins), and attempt reroute.
|
||||
reroute_kwargs are forwarded to compute_reroute (risk_provider will call fetch_road_risk
|
||||
using provided api_key/roadrisk_url).
|
||||
"""
|
||||
if reroute_kwargs is None:
|
||||
reroute_kwargs = {}
|
||||
|
||||
# obtain base risk
|
||||
data, features = fetch_road_risk(lat, lon, api_key=api_key, roadrisk_url=roadrisk_url)
|
||||
road_risk = float(features.get("road_risk_score", 0.0))
|
||||
|
||||
# compute index: if 'accidents' present in features, prefer that mapping
|
||||
accidents = features.get("accidents") or features.get("accident_count")
|
||||
try:
|
||||
if accidents is not None:
|
||||
# map raw accident count to index 1..num_bins
|
||||
from .models import accidents_to_bucket
|
||||
idx = accidents_to_bucket(int(accidents), max_count=20000, num_bins=num_bins)
|
||||
else:
|
||||
idx = risk_to_index(road_risk, max_risk=max_risk, num_bins=num_bins)
|
||||
except Exception:
|
||||
idx = risk_to_index(road_risk, max_risk=max_risk, num_bins=num_bins)
|
||||
|
||||
# prepare risk_provider that passes api_key/roadrisk_url through
|
||||
def _rp(lat_, lon_):
|
||||
return get_risk_score(lat_, lon_, api_key=api_key, roadrisk_url=roadrisk_url)
|
||||
|
||||
reroute_info = compute_reroute(lat, lon, risk_provider=_rp, **reroute_kwargs)
|
||||
return {
|
||||
"lat": lat,
|
||||
"lon": lon,
|
||||
"index": int(idx),
|
||||
"road_risk_score": road_risk,
|
||||
"features": features,
|
||||
"reroute": reroute_info,
|
||||
"raw_roadrisk_response": data,
|
||||
}
|
||||
@@ -1,339 +0,0 @@
|
||||
"""
|
||||
Fetch OpenWeather data for a coordinate/time and run the trained MLP to predict the k-means cluster label.
|
||||
|
||||
Usage examples:
|
||||
# with training CSV provided to compute preprocessing stats:
|
||||
python openweather_inference.py --lat 38.9 --lon -77.0 --datetime "2025-09-27T12:00:00" --train-csv data.csv --model model.pth --centers kmeans_centers_all.npz --api-key $OPENWEATHER_KEY
|
||||
|
||||
# with precomputed preprocess meta (saved from training):
|
||||
python openweather_inference.py --lat 38.9 --lon -77.0 --datetime "2025-09-27T12:00:00" --preprocess-meta preprocess_meta.npz --model model.pth --centers kmeans_centers_all.npz --api-key $OPENWEATHER_KEY
|
||||
|
||||
Notes:
|
||||
- The script uses the same feature-engineering helpers in `data.py` so the model sees identical inputs.
|
||||
- You must either provide `--train-csv` (to compute feature columns & means/stds) or `--preprocess-meta` previously saved.
|
||||
- Provide the OpenWeather API key via --api-key or the OPENWEATHER_KEY environment variable.
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# reuse helpers from your repo
|
||||
from data import _add_date_features, _add_latlon_bins, _add_hashed_street, CSVDataset
|
||||
from inference import load_model
|
||||
|
||||
# module-level caches to avoid reloading heavy artifacts per request
|
||||
_CACHED_MODEL = None
|
||||
_CACHED_IDX_TO_CLASS = None
|
||||
_CACHED_CENTERS = None
|
||||
_CACHED_PREPROCESS_META = None
|
||||
|
||||
|
||||
OW_BASE = 'https://api.openweathermap.org/data/2.5/onecall'
|
||||
|
||||
|
||||
def fetch_openweather(lat, lon, api_key, dt_iso=None):
|
||||
"""Fetch weather from OpenWeather One Call API for given lat/lon. If dt_iso provided, we fetch current+hourly and pick closest timestamp."""
|
||||
try:
|
||||
import requests
|
||||
except Exception:
|
||||
raise RuntimeError('requests library is required to fetch OpenWeather data')
|
||||
params = {
|
||||
'lat': float(lat),
|
||||
'lon': float(lon),
|
||||
'appid': api_key,
|
||||
'units': 'metric',
|
||||
'exclude': 'minutely,alerts'
|
||||
}
|
||||
r = requests.get(OW_BASE, params=params, timeout=10)
|
||||
r.raise_for_status()
|
||||
payload = r.json()
|
||||
# if dt_iso provided, find nearest hourly data point
|
||||
if dt_iso:
|
||||
try:
|
||||
target = pd.to_datetime(dt_iso)
|
||||
except Exception:
|
||||
target = None
|
||||
best = None
|
||||
if 'hourly' in payload and target is not None:
|
||||
hours = payload['hourly']
|
||||
best = min(hours, key=lambda h: abs(pd.to_datetime(h['dt'], unit='s') - target))
|
||||
# convert keys to a flat dict with prefix 'ow_'
|
||||
d = {
|
||||
'ow_temp': best.get('temp'),
|
||||
'ow_feels_like': best.get('feels_like'),
|
||||
'ow_pressure': best.get('pressure'),
|
||||
'ow_humidity': best.get('humidity'),
|
||||
'ow_wind_speed': best.get('wind_speed'),
|
||||
'ow_clouds': best.get('clouds'),
|
||||
'ow_pop': best.get('pop'),
|
||||
}
|
||||
return d
|
||||
# fallback: use current
|
||||
cur = payload.get('current', {})
|
||||
d = {
|
||||
'ow_temp': cur.get('temp'),
|
||||
'ow_feels_like': cur.get('feels_like'),
|
||||
'ow_pressure': cur.get('pressure'),
|
||||
'ow_humidity': cur.get('humidity'),
|
||||
'ow_wind_speed': cur.get('wind_speed'),
|
||||
'ow_clouds': cur.get('clouds'),
|
||||
'ow_pop': None,
|
||||
}
|
||||
return d
|
||||
|
||||
|
||||
def fetch_roadrisk(roadrisk_url, api_key=None):
|
||||
"""Fetch the RoadRisk endpoint (expects JSON). If `api_key` is provided, we'll attach it as a query param if the URL has no key.
|
||||
|
||||
We flatten top-level numeric fields into `rr_*` keys for the feature row.
|
||||
"""
|
||||
# if api_key provided and url does not contain appid, append it
|
||||
try:
|
||||
import requests
|
||||
except Exception:
|
||||
raise RuntimeError('requests library is required to fetch RoadRisk data')
|
||||
url = roadrisk_url
|
||||
if api_key and 'appid=' not in roadrisk_url:
|
||||
sep = '&' if '?' in roadrisk_url else '?'
|
||||
url = f"{roadrisk_url}{sep}appid={api_key}"
|
||||
|
||||
r = requests.get(url, timeout=10)
|
||||
r.raise_for_status()
|
||||
payload = r.json()
|
||||
# flatten numeric top-level fields
|
||||
out = {}
|
||||
if isinstance(payload, dict):
|
||||
for k, v in payload.items():
|
||||
if isinstance(v, (int, float)):
|
||||
out[f'rr_{k}'] = v
|
||||
# if nested objects contain simple numeric fields, pull them too (one level deep)
|
||||
elif isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
if isinstance(vv, (int, float)):
|
||||
out[f'rr_{k}_{kk}'] = vv
|
||||
return out
|
||||
|
||||
|
||||
def build_row(lat, lon, dt_iso=None, street=None, extra_weather=None):
|
||||
"""Construct a single-row DataFrame with columns expected by the training pipeline.
|
||||
|
||||
It intentionally uses column names the original `data.py` looked for (REPORTDATE, LATITUDE, LONGITUDE, ADDRESS, etc.).
|
||||
"""
|
||||
row = {}
|
||||
# date column matching common names
|
||||
row['REPORTDATE'] = dt_iso if dt_iso else datetime.utcnow().isoformat()
|
||||
row['LATITUDE'] = lat
|
||||
row['LONGITUDE'] = lon
|
||||
row['ADDRESS'] = street if street else ''
|
||||
# include some injury/fatality placeholders that the label generator expects
|
||||
row['INJURIES'] = 0
|
||||
row['FATALITIES'] = 0
|
||||
# include weather features returned by OpenWeather (prefixed 'ow_')
|
||||
if extra_weather:
|
||||
for k, v in extra_weather.items():
|
||||
row[k] = v
|
||||
return pd.DataFrame([row])
|
||||
|
||||
|
||||
def prepare_features(df_row, train_csv=None, preprocess_meta=None, feature_engineer=True, lat_lon_bins=20):
|
||||
"""Given a one-row DataFrame, apply same feature engineering and standardization as training.
|
||||
|
||||
If preprocess_meta is provided (npz), use it. Otherwise train_csv must be provided to compute stats.
|
||||
Returns a torch.FloatTensor of shape (1, input_dim) and the feature_columns list.
|
||||
"""
|
||||
# apply feature engineering helpers
|
||||
if feature_engineer:
|
||||
try:
|
||||
_add_date_features(df_row)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
_add_latlon_bins(df_row, bins=lat_lon_bins)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
_add_hashed_street(df_row)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# if meta provided, load feature_columns, means, stds
|
||||
if preprocess_meta and os.path.exists(preprocess_meta):
|
||||
meta = np.load(preprocess_meta, allow_pickle=True)
|
||||
feature_columns = meta['feature_columns'].tolist()
|
||||
means = meta['means']
|
||||
stds = meta['stds']
|
||||
else:
|
||||
if not train_csv:
|
||||
raise ValueError('Either preprocess_meta or train_csv must be provided to derive feature stats')
|
||||
# instantiate a CSVDataset on train_csv (feature_engineer True) to reuse its preprocessing
|
||||
ds = CSVDataset(train_csv, feature_columns=None, label_column='label', generate_labels=False, n_buckets=10, label_method='kmeans', label_store=None, feature_engineer=feature_engineer, lat_lon_bins=lat_lon_bins, nrows=None)
|
||||
feature_columns = ds.feature_columns
|
||||
means = ds.feature_means
|
||||
stds = ds.feature_stds
|
||||
# save meta for reuse
|
||||
np.savez_compressed('preprocess_meta.npz', feature_columns=np.array(feature_columns, dtype=object), means=means, stds=stds)
|
||||
print('Saved preprocess_meta.npz')
|
||||
|
||||
# ensure all feature columns exist in df_row
|
||||
for c in feature_columns:
|
||||
if c not in df_row.columns:
|
||||
df_row[c] = 0
|
||||
|
||||
# coerce and fill using means
|
||||
features_df = df_row[feature_columns].apply(lambda c: pd.to_numeric(c, errors='coerce'))
|
||||
features_df = features_df.fillna(pd.Series(means, index=feature_columns)).fillna(0.0)
|
||||
# standardize
|
||||
features_np = (features_df.values - means) / (stds + 1e-6)
|
||||
import torch
|
||||
return torch.tensor(features_np, dtype=torch.float32), feature_columns
|
||||
|
||||
|
||||
def predict_from_openweather(lat, lon, dt_iso=None, street=None, api_key=None, train_csv=None, preprocess_meta=None, model_path='model.pth', centers_path='kmeans_centers_all.npz', roadrisk_url=None):
|
||||
api_key = api_key or os.environ.get('OPENWEATHER_KEY')
|
||||
if api_key is None:
|
||||
raise ValueError('OpenWeather API key required via --api-key or OPENWEATHER_KEY env var')
|
||||
|
||||
# gather weather/road-risk features
|
||||
weather = {}
|
||||
if roadrisk_url:
|
||||
try:
|
||||
rr = fetch_roadrisk(roadrisk_url, api_key=api_key)
|
||||
weather.update(rr)
|
||||
except Exception as e:
|
||||
print('Warning: failed to fetch roadrisk URL:', e)
|
||||
else:
|
||||
try:
|
||||
ow = fetch_openweather(lat, lon, api_key, dt_iso=dt_iso)
|
||||
weather.update(ow)
|
||||
except Exception as e:
|
||||
print('Warning: failed to fetch openweather:', e)
|
||||
|
||||
df_row = build_row(lat, lon, dt_iso=dt_iso, street=street, extra_weather=weather)
|
||||
x_tensor, feature_columns = prepare_features(df_row, train_csv=train_csv, preprocess_meta=preprocess_meta)
|
||||
|
||||
# load model (infer num_classes from centers file if possible)
|
||||
global _CACHED_MODEL, _CACHED_IDX_TO_CLASS, _CACHED_CENTERS, _CACHED_PREPROCESS_META
|
||||
|
||||
# ensure we have preprocess_meta available (prefer supplied path, otherwise fallback to saved file)
|
||||
if preprocess_meta is None:
|
||||
candidate = os.path.join(os.getcwd(), 'preprocess_meta.npz')
|
||||
if os.path.exists(candidate):
|
||||
preprocess_meta = candidate
|
||||
|
||||
# load centers (cache across requests)
|
||||
if _CACHED_CENTERS is None:
|
||||
if centers_path and os.path.exists(centers_path):
|
||||
try:
|
||||
npz = np.load(centers_path)
|
||||
_CACHED_CENTERS = npz['centers']
|
||||
except Exception:
|
||||
_CACHED_CENTERS = None
|
||||
else:
|
||||
_CACHED_CENTERS = None
|
||||
|
||||
num_classes = _CACHED_CENTERS.shape[0] if _CACHED_CENTERS is not None else 10
|
||||
|
||||
# load model once and cache it
|
||||
if _CACHED_MODEL is None:
|
||||
try:
|
||||
_CACHED_MODEL, _CACHED_IDX_TO_CLASS = load_model(model_path, device=None, in_channels=3, num_classes=num_classes)
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
_CACHED_MODEL.to(device)
|
||||
except Exception as e:
|
||||
raise
|
||||
model = _CACHED_MODEL
|
||||
idx_to_class = _CACHED_IDX_TO_CLASS
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
x_tensor = x_tensor.to(device)
|
||||
with torch.no_grad():
|
||||
logits = model(x_tensor)
|
||||
probs = F.softmax(logits, dim=1).cpu().numpy()[0]
|
||||
pred_idx = int(probs.argmax())
|
||||
confidence = float(probs.max())
|
||||
|
||||
# optionally provide cluster centroid info
|
||||
centroid = _CACHED_CENTERS[pred_idx] if _CACHED_CENTERS is not None else None
|
||||
|
||||
return {
|
||||
'pred_cluster': int(pred_idx),
|
||||
'confidence': confidence,
|
||||
'probabilities': probs.tolist(),
|
||||
'centroid': centroid.tolist() if centroid is not None else None,
|
||||
'feature_columns': feature_columns,
|
||||
'used_preprocess_meta': preprocess_meta
|
||||
}
|
||||
|
||||
|
||||
def init_inference(model_path='model.pth', centers_path='kmeans_centers_all.npz', preprocess_meta=None):
|
||||
"""Eagerly load model, centers, and preprocess_meta into module-level caches.
|
||||
|
||||
This is intended to be called at app startup to surface load errors early and avoid
|
||||
per-request disk IO. The function is best-effort and will print warnings if artifacts
|
||||
are missing.
|
||||
"""
|
||||
global _CACHED_MODEL, _CACHED_IDX_TO_CLASS, _CACHED_CENTERS, _CACHED_PREPROCESS_META
|
||||
|
||||
# prefer existing saved preprocess_meta if not explicitly provided
|
||||
if preprocess_meta is None:
|
||||
candidate = os.path.join(os.getcwd(), 'preprocess_meta.npz')
|
||||
if os.path.exists(candidate):
|
||||
preprocess_meta = candidate
|
||||
|
||||
_CACHED_PREPROCESS_META = preprocess_meta
|
||||
|
||||
# load centers
|
||||
if _CACHED_CENTERS is None:
|
||||
if centers_path and os.path.exists(centers_path):
|
||||
try:
|
||||
npz = np.load(centers_path)
|
||||
_CACHED_CENTERS = npz['centers']
|
||||
print(f'Loaded centers from {centers_path}')
|
||||
except Exception as e:
|
||||
print('Warning: failed to load centers:', e)
|
||||
_CACHED_CENTERS = None
|
||||
else:
|
||||
print('No centers file found at', centers_path)
|
||||
_CACHED_CENTERS = None
|
||||
|
||||
num_classes = _CACHED_CENTERS.shape[0] if _CACHED_CENTERS is not None else 10
|
||||
|
||||
# load model
|
||||
if _CACHED_MODEL is None:
|
||||
try:
|
||||
_CACHED_MODEL, _CACHED_IDX_TO_CLASS = load_model(model_path, device=None, in_channels=3, num_classes=num_classes)
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
_CACHED_MODEL.to(device)
|
||||
print(f'Loaded model from {model_path}')
|
||||
except Exception as e:
|
||||
print('Warning: failed to load model:', e)
|
||||
_CACHED_MODEL = None
|
||||
|
||||
return {
|
||||
'model_loaded': _CACHED_MODEL is not None,
|
||||
'centers_loaded': _CACHED_CENTERS is not None,
|
||||
'preprocess_meta': _CACHED_PREPROCESS_META
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lat', type=float, required=True)
    parser.add_argument('--lon', type=float, required=True)
    parser.add_argument('--datetime', default=None, help='ISO datetime string to query hourly weather (optional)')
    parser.add_argument('--street', default='')
    parser.add_argument('--api-key', default=None, help='OpenWeather API key or use OPENWEATHER_KEY env var')
    parser.add_argument('--train-csv', default=None, help='Path to training CSV to compute preprocessing stats (optional if --preprocess-meta provided)')
    parser.add_argument('--preprocess-meta', default=None, help='Path to precomputed preprocess_meta.npz (optional)')
    parser.add_argument('--model', default='model.pth')
    parser.add_argument('--centers', default='kmeans_centers_all.npz')
    parser.add_argument('--roadrisk-url', default=None, help='Optional custom RoadRisk API URL (if provided, will be queried instead of OneCall)')
    args = parser.parse_args()

    out = predict_from_openweather(args.lat, args.lon, dt_iso=args.datetime, street=args.street, api_key=args.api_key, train_csv=args.train_csv, preprocess_meta=args.preprocess_meta, model_path=args.model, centers_path=args.centers, roadrisk_url=args.roadrisk_url)
    print(json.dumps(out, indent=2))
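For quick local testing the same entry point can also be called directly from Python; the coordinates below are placeholders and, as with the CLI above, api_key=None is assumed to fall back to the OPENWEATHER_KEY environment variable.

    from openmeteo_inference import predict_from_openweather
    out = predict_from_openweather(38.9, -77.0, dt_iso=None, street='',
                                   api_key=None,  # falls back to OPENWEATHER_KEY
                                   model_path='model.pth',
                                   centers_path='kmeans_centers_all.npz')
    print(out)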

@@ -1,6 +1,7 @@
import os
import time
import torch
import json
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
@@ -52,18 +53,106 @@ def train(dataset_root, epochs=3, batch_size=16, lr=1e-3, device=None, num_class
        print(f'Saved preprocess meta to {meta_path}')
    except Exception:
        pass

    # ensure model_num_classes has a defined value for later use
    model_num_classes = None

    # ---- new: generate kmeans labels when requested ----
    if generate_labels and label_method == 'kmeans':
        try:
            import numpy as _np
            # ensure features is numpy 2D array
            X = dataset.features
            if hasattr(X, 'toarray'):
                X = X.toarray()
            X = _np.asarray(X, dtype=float)

            # basic preprocessing: fill NaN and scale (mean/std)
            nan_mask = _np.isnan(X)
            if nan_mask.any():
                col_means = _np.nanmean(X, axis=0)
                inds = _np.where(nan_mask)
                X[inds] = _np.take(col_means, inds[1])
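                # note: inds is a (row_indices, col_indices) pair, so _np.take(col_means, inds[1])
                # looks up the mean of each NaN's own column; a NaN at (3, 2) becomes col_means[2]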
            # standardize
            col_means = X.mean(axis=0)
            col_stds = X.std(axis=0)
            col_stds[col_stds == 0] = 1.0
            Xs = (X - col_means) / col_stds

            # use sklearn KMeans
            try:
                from sklearn.cluster import KMeans
            except Exception as e:
                raise RuntimeError("sklearn is required for kmeans label generation: " + str(e))

            n_clusters = 10  # produce 1..10 labels as required
            kmeans = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10)
            cluster_ids = kmeans.fit_predict(Xs)

            # compute a simple score per cluster to sort them (e.g., center mean)
            centers = kmeans.cluster_centers_
            center_scores = centers.mean(axis=1)
            # sort cluster ids by score -> map to rank 1..n_clusters (1 = lowest score)
            order = _np.argsort(center_scores)
            rank_map = {int(c): (int(_np.where(order == c)[0][0]) + 1) for c in range(len(order))}
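            # worked example (illustrative): with 3 clusters and center_scores = [0.4, -1.2, 0.9],
            # order = [1, 0, 2], so rank_map = {1: 1, 0: 2, 2: 3}: the lowest-scoring cluster
            # gets rank 1 and the highest-scoring cluster gets rank n_clusters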

            # assign labels 1..10 based on cluster rank
            assigned_labels_1to10 = [_np.float64(rank_map[int(cid)]) for cid in cluster_ids]

            # for training (classification) convert to 0..9 integer labels
            assigned_labels_zero_based = _np.array([int(lbl) - 1 for lbl in assigned_labels_1to10], dtype=int)

            # attach to dataset (CSVDataset consumers may expect a .labels attribute)
            try:
                import torch as _torch
                dataset.labels = _torch.from_numpy(assigned_labels_zero_based).long()
            except Exception:
                # fall back to a plain numpy attribute
                dataset.labels = assigned_labels_zero_based

            # persist label_info / assignments
            label_info = {
                "generated": True,
                "label_method": "kmeans",
                "n_clusters": n_clusters,
            }
            try:
                # save assignments inline if small enough, otherwise to a compressed .npz
                if len(assigned_labels_1to10) <= 100000:
                    label_info["assignments"] = [float(x) for x in assigned_labels_1to10]
                else:
                    arr_path = os.path.join(output_dir, "label_assignments.npz")
                    _np.savez_compressed(arr_path, assignments=_np.array(assigned_labels_1to10))
                    label_info["assignments_file"] = os.path.basename(arr_path)
                with open(os.path.join(output_dir, "label_info.json"), "w") as f:
                    json.dump(label_info, f)
            except Exception:
                pass
            # update model_num_classes for training (10 clusters)
            model_num_classes = n_clusters
            print(f"Generated kmeans labels with {n_clusters} clusters; saved label_info.json")
        except Exception as e:
            print("KMeans label generation failed:", e)
            # fall back to prior logic (md5 or provided labels)
    # ---- end kmeans generation ----

    if model_type == 'cnn':
        raise ValueError('CSV dataset should use model_type="mlp"')

    # determine model_num_classes if not set by kmeans above
    if model_num_classes is None:
        # if we generated labels (non-kmeans) and dataset provides labels, infer number of classes
        if generate_labels and hasattr(dataset, 'labels') and label_method != 'kmeans':
            try:
                model_num_classes = int(dataset.labels.max().item()) + 1
            except Exception:
                model_num_classes = n_buckets
        else:
            # default behavior
            model_num_classes = n_buckets if generate_labels else num_classes

    # If labels were generated, save label metadata + assignments (if not huge)
    if generate_labels:
        if label_method != 'kmeans':
            try:
                label_info = {
                    "generated": True,
@@ -86,7 +175,6 @@ def train(dataset_root, epochs=3, batch_size=16, lr=1e-3, device=None, num_class
                except Exception:
                    pass
                with open(os.path.join(output_dir, "label_info.json"), "w") as f:
                    import json
                    json.dump(label_info, f)
                print(f"Saved label_info to {os.path.join(output_dir, 'label_info.json')}")
            except Exception:
@@ -170,7 +258,6 @@ def train(dataset_root, epochs=3, batch_size=16, lr=1e-3, device=None, num_class

if __name__ == '__main__':
    import argparse
    import json
    parser = argparse.ArgumentParser()
    parser.add_argument('data_root')
    parser.add_argument('--epochs', type=int, default=3)
@@ -210,3 +297,44 @@ if __name__ == '__main__':
        json.dump(label_info, f)
    train(data_root, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, model_type=args.model_type, csv_label=args.csv_label, generate_labels=args.generate_labels, n_buckets=args.n_buckets, label_method=args.label_method, label_store=args.label_store, feature_engineer=args.feature_engineer, lat_lon_bins=args.lat_lon_bins, nrows=nrows, seed=args.seed, hidden_dims=hidden_dims, weight_decay=args.weight_decay, output_dir=args.output_dir)
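As a rough sketch of the kmeans-label path, train() can also be driven without the CLI; the dataset path and output directory below are hypothetical, and only keyword arguments that appear in the call above are used.

    from train import train
    train('data/accidents.csv',                      # hypothetical dataset_root
          epochs=3, batch_size=16, lr=1e-3,
          model_type='mlp',
          generate_labels=True, label_method='kmeans', n_buckets=10,
          seed=42, output_dir='runs/kmeans_labels')  # hypothetical output_dir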


# ---------------- new helper ----------------
def compute_index(model, feature_vector):
    """
    Run model on a single feature_vector and return the model-provided index as float.
    - If model is a classifier (outputs logits), returns argmax + 1.0 (so labels 1..C).
    - If model returns a single scalar regression, returns that scalar as float.
    feature_vector may be numpy array or torch tensor (1D or 2D single sample).
    """
    try:
        import torch
        model.eval()
        if not isinstance(feature_vector, torch.Tensor):
            fv = torch.tensor(feature_vector, dtype=torch.float32)
        else:
            fv = feature_vector.float()
        # ensure batch dim
        if fv.dim() == 1:
            fv = fv.unsqueeze(0)
        with torch.no_grad():
            out = model(fv)
        # if tensor output
        if hasattr(out, 'detach'):
            out_t = out.detach().cpu()
            if out_t.ndim == 2 and out_t.shape[1] > 1:
                # classifier logits/probs -> argmax
                idx = int(out_t.argmax(dim=1).item())
                return float(idx + 1)
            elif out_t.numel() == 1:
                return float(out_t.item())
            else:
                # fallback: return first element
                return float(out_t.flatten()[0].item())
        else:
            # not a tensor (unlikely), try float conversion
            return float(out)
    except Exception as e:
        raise RuntimeError("compute_index failed: " + str(e))
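A small sanity-check sketch (editorial; the two-layer toy network stands in for the trained MLP): with a 10-way classification head, compute_index returns a float in 1..10.

    import numpy as np
    import torch.nn as nn
    toy_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10))
    sample = np.random.rand(8).astype('float32')   # one feature vector with 8 features
    print(compute_index(toy_model, sample))        # e.g. 7.0 = argmax of the 10 logits, plus 1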