"""Neural-network model definitions and checkpoint helpers.

Contents:
    accidents_to_bucket -- map raw accident counts to equal-width buckets.
    SimpleCNN           -- small convolutional image classifier.
    MLP                 -- fully-connected classifier for tabular (CSV) data.
    load_model          -- rebuild a model from a saved checkpoint.
    create_model        -- factory for SimpleCNN / MLP instances.
"""

import math
import os
from collections.abc import Iterable
from typing import Union

import numpy as np
import torch
import torch.nn as nn


def accidents_to_bucket(count: Union[int, float, Iterable],
                        max_count: int = 20000,
                        num_bins: int = 10) -> Union[int, list, torch.Tensor, np.ndarray]:
    """Map accident counts to equal-width buckets 1..num_bins.

    Example: max_count=20000, num_bins=10 -> bin width = 2000
        0-1999 -> 1, 2000-3999 -> 2, ..., 18000-20000 -> 10

    Args:
        count: single value or iterable (list / numpy array / torch tensor).
            Values <= 0 map to 1; values >= max_count map to num_bins.
        max_count: expected maximum count (top of the highest bin).
        num_bins: number of buckets (default 10).

    Returns:
        Same container type as the input (int for scalars, list / numpy /
        torch for iterables) with values in 1..num_bins.

    Raises:
        ValueError: if max_count or num_bins is not positive.
    """
    if max_count <= 0 or num_bins <= 0:
        raise ValueError("max_count and num_bins must be positive")
    width = max_count / float(num_bins)

    def _bucket_scalar(x):
        # Treat missing values as zero, then clamp into [1, num_bins].
        x = 0.0 if x is None else float(x)
        if x <= 0:
            return 1
        if x >= max_count:
            return num_bins
        return int(x // width) + 1

    # Scalar int/float (bool is an int subclass and lands here too).
    if isinstance(count, (int, float)):
        return _bucket_scalar(count)

    # Torch tensor: vectorized clamp / floor-divide / shift to 1-based bins.
    if isinstance(count, torch.Tensor):
        x = torch.clamp(count.clone().float(), min=0.0, max=float(max_count))
        buckets = (x // width).to(torch.long) + 1
        return torch.clamp(buckets, min=1, max=num_bins)

    # Numpy array: same vectorized scheme.
    if isinstance(count, np.ndarray):
        x = np.clip(count.astype(float), 0.0, float(max_count))
        buckets = (x // width).astype(int) + 1
        return np.clip(buckets, 1, num_bins)

    # Generic iterable -> list of ints.
    if isinstance(count, Iterable):
        return [_bucket_scalar(float(x)) for x in count]

    # Fallback for scalar-like types (e.g. numpy integer scalars).
    return _bucket_scalar(float(count))


class SimpleCNN(nn.Module):
    """Small CNN for image classification.

    The flattened feature size for the classifier head is computed
    automatically from ``input_size`` via a one-off dummy forward pass.
    """

    def __init__(self, in_channels=3, num_classes=10, input_size=(3, 224, 224)):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Probe the conv stack once to obtain the per-sample flattened size
        # (C * H * W of the feature map), independent of batch size.
        with torch.no_grad():
            dummy = torch.zeros(1, *input_size)
            feat = self.features(dummy)
        flat_features = math.prod(feat.shape[1:])
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Apply the conv feature extractor, then the classifier head."""
        return self.classifier(self.features(x))


class MLP(nn.Module):
    """Simple MLP for tabular CSV data classification."""

    def __init__(self, input_dim=58, hidden_dims=(1024, 512, 50), num_classes=10):
        super().__init__()
        layers = []
        prev = input_dim
        for h in hidden_dims:
            # Each hidden block: Linear -> ReLU -> Dropout(0.2).
            layers.extend([nn.Linear(prev, h), nn.ReLU(), nn.Dropout(0.2)])
            prev = h
        layers.append(nn.Linear(prev, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def _extract_state_dict(ckpt):
    """Locate the raw state_dict inside a loaded checkpoint object.

    Accepts either a wrapper dict (keys 'model_state_dict' / 'state_dict' /
    'model') or a file that contains the state_dict directly.
    """
    if isinstance(ckpt, dict):
        for key in ('model_state_dict', 'state_dict', 'model'):
            if key in ckpt and isinstance(ckpt[key], dict):
                return ckpt[key]
        # The checkpoint may BE the state_dict (tensor entries like 'x.weight').
        if any(k.endswith('.weight') for k in ckpt.keys()):
            return ckpt
    raise ValueError("No state_dict found in checkpoint")


def _infer_mlp_from_state(state_dict):
    """Infer (input_dim, hidden_dims, num_classes) from Linear weight shapes."""
    weight_items = []
    for k in state_dict.keys():
        # MLP stores its layers under the 'net' Sequential: net.<idx>.weight.
        if k.endswith('.weight') and k.startswith('net.'):
            try:
                idx = int(k.split('.')[1])
            except (ValueError, IndexError):
                continue
            weight_items.append((idx, k))
    if not weight_items:
        # Fallback: take all weight-like keys in lexicographic order.
        weight_items = list(enumerate(sorted(
            k for k in state_dict.keys() if k.endswith('.weight'))))
    weight_items.sort()
    shapes = [tuple(state_dict[k].shape) for _, k in weight_items]
    if not shapes:
        raise ValueError("Cannot infer MLP structure from state_dict")
    # nn.Linear weights are (out_features, in_features): the first layer's
    # in_features is the input dim, each non-final out_features is a hidden
    # width, and the last out_features is the class count.
    input_dim = int(shapes[0][1])
    hidden_dims = tuple(int(s[0]) for s in shapes[:-1])
    num_classes = int(shapes[-1][0])
    return input_dim, hidden_dims, num_classes


def _build_mlp(state, model_config, input_dim):
    """Build an MLP whose architecture matches the config and/or state_dict."""
    use_input_dim = input_dim if input_dim is not None else model_config.get('input_dim')
    use_hidden = (model_config.get('hidden_dims')
                  or model_config.get('hidden_dim')
                  or model_config.get('hidden'))
    use_num_classes = model_config.get('num_classes')

    # Fill any missing piece from the state_dict shapes.  (Previously the
    # inference ran only when input_dim/num_classes were missing, so a config
    # lacking just hidden_dims fell back to (256, 128) and the subsequent
    # load_state_dict was guaranteed to fail on shape mismatch.)
    if use_input_dim is None or use_num_classes is None or use_hidden is None:
        inferred_input, inferred_hidden, inferred_num = _infer_mlp_from_state(state)
        use_input_dim = inferred_input if use_input_dim is None else use_input_dim
        use_hidden = inferred_hidden if use_hidden is None else use_hidden
        use_num_classes = inferred_num if use_num_classes is None else use_num_classes

    if isinstance(use_hidden, (list, tuple)):
        use_hidden = tuple(use_hidden)
    else:
        # Sometimes stored as a string such as "(256, 128)".
        try:
            use_hidden = tuple(int(x) for x in str(use_hidden).strip('()[]').split(',') if x)
        except ValueError:
            use_hidden = (256, 128)

    return MLP(input_dim=int(use_input_dim), hidden_dims=use_hidden,
               num_classes=int(use_num_classes))


def load_model(model_path, model_class, input_dim=None):
    """Load weights from ``model_path`` and build a ``model_class`` instance.

    Architecture resolution order:
      1. an explicit ``model_config`` / ``config`` dict in the checkpoint;
      2. for MLP, shapes inferred from the state_dict's ``net.*.weight``
         tensors when any parameter is missing.

    Args:
        model_path: path to a ``torch.save`` checkpoint.  The file may hold
            the state_dict directly or wrap it under ``model_state_dict`` /
            ``state_dict`` / ``model``.
        model_class: MLP or SimpleCNN.
        input_dim: optional override for the MLP input dimension.

    Returns:
        The instantiated model with weights loaded (on CPU).

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        ValueError: if no state_dict is found or ``model_class`` is unsupported.
        RuntimeError: if the state_dict does not fit the built model.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"model file not found: {model_path}")
    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(model_path, map_location=torch.device('cpu'))

    state = _extract_state_dict(ckpt)
    # Prefer an explicit config if the checkpoint wrapper carries one.
    model_config = (ckpt.get('model_config') or ckpt.get('config') or {}) \
        if isinstance(ckpt, dict) else {}

    if model_class == MLP:
        model = _build_mlp(state, model_config, input_dim)
    elif model_class == SimpleCNN:
        cfg_num_classes = model_config.get('num_classes') or 10
        cfg_input_size = tuple(model_config.get('input_size') or (3, 224, 224))
        model = SimpleCNN(in_channels=cfg_input_size[0],
                          num_classes=int(cfg_num_classes),
                          input_size=cfg_input_size)
    else:
        raise ValueError(f"Unsupported model class: {model_class}")

    try:
        model.load_state_dict(state)
    except Exception as e:
        # Provide helpful diagnostics: a sample of both key sets.
        model_keys = list(model.state_dict().keys())[:50]
        state_keys = list(state.keys())[:50]
        raise RuntimeError(
            f"Failed to load state_dict: {e}. "
            f"model_keys_sample={model_keys}, state_keys_sample={state_keys}"
        ) from e
    return model


def create_model(device=None, in_channels=3, num_classes=10, input_size=(3, 224, 224),
                 model_type='cnn', input_dim=None, hidden_dims=None):
    """Create and return a model based on the provided configuration.

    Args:
        device: optional device ('cpu'/'cuda' or torch.device) to move the model to.
        in_channels: number of input channels for the CNN (default 3 for RGB).
        num_classes: number of output classes (default 10).
        input_size: CNN input size as (C, H, W) (default (3, 224, 224)).
        model_type: 'cnn' for SimpleCNN (default) or 'mlp' for MLP.
        input_dim: input dimension for the MLP (required when model_type == 'mlp').
        hidden_dims: hidden layer widths for the MLP (defaults to (256, 128)).

    Returns:
        The created nn.Module.

    Raises:
        ValueError: if model_type == 'mlp' and input_dim is None.
    """
    if model_type == 'mlp':
        if input_dim is None:
            raise ValueError('input_dim is required for mlp model_type')
        model = MLP(input_dim=input_dim,
                    hidden_dims=hidden_dims or (256, 128),
                    num_classes=num_classes)
    else:
        model = SimpleCNN(in_channels=in_channels, num_classes=num_classes,
                          input_size=input_size)
    if device:
        model.to(device)
    return model


# Example usage:
#   model = load_model('path_to_model.pth', SimpleCNN)
#   model = create_model(device='cuda', model_type='cnn', num_classes=5)