Added Weather API

Author: Pranav Malladi
Date: 2025-09-27 18:13:53 -04:00
Parent: 2471610d80
Commit: 629444c382
22 changed files with 629 additions and 308 deletions

@@ -9,8 +9,10 @@ from data import ImageFolderDataset, CSVDataset
 from models import create_model
-def train(dataset_root, epochs=3, batch_size=16, lr=1e-3, device=None, num_classes=10, model_type='cnn', csv_label='label', generate_labels=False, n_buckets=100, label_method='md5', label_store=None, feature_engineer=False, lat_lon_bins=20, nrows=None, seed=42, hidden_dims=None, weight_decay=0.0):
+def train(dataset_root, epochs=3, batch_size=16, lr=1e-3, device=None, num_classes=10, model_type='mlp', csv_label='label', generate_labels=False, n_buckets=100, label_method='md5', label_store=None, feature_engineer=False, lat_lon_bins=20, nrows=None, seed=42, hidden_dims=None, weight_decay=0.0, output_dir=None):
     device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+    output_dir = output_dir or os.getcwd()
+    os.makedirs(output_dir, exist_ok=True)
     # Detect CSV vs folder dataset
     if os.path.isfile(dataset_root) and dataset_root.lower().endswith('.csv'):
         dataset = CSVDataset(dataset_root,
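The new output_dir parameter is backward compatible: omitting it keeps the old save-to-cwd behavior, while passing a directory collects every artifact (model.pth, preprocess_meta.npz, label files) in one place. A minimal sketch of both call styles, assuming this module is importable as train.py and using an illustrative dataset path:

    from train import train  # assumes this file is saved as train.py

    # Old behavior preserved: artifacts land in the current working directory.
    train('data/samples.csv', epochs=3)

    # New behavior: artifacts are collected under one directory, created if missing.
    train('data/samples.csv', epochs=3, output_dir='runs/exp1')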
@@ -45,7 +47,7 @@ def train(dataset_root, epochs=3, batch_size=16, lr=1e-3, device=None, num_class
         # persist preprocessing metadata so inference can reuse identical stats
         try:
             import numpy as _np
-            meta_path = os.path.join(os.getcwd(), 'preprocess_meta.npz')
+            meta_path = os.path.join(output_dir, 'preprocess_meta.npz')
             _np.savez_compressed(meta_path, feature_columns=_np.array(dataset.feature_columns, dtype=object), means=dataset.feature_means, stds=dataset.feature_stds)
             print(f'Saved preprocess meta to {meta_path}')
         except Exception:
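Writing preprocess_meta.npz into output_dir means inference code can find the training-time normalization stats next to the checkpoint. A sketch of a hypothetical inference-side loader; the key names mirror the savez_compressed call above, and allow_pickle is required because feature_columns is saved as an object-dtype array:

    import numpy as np

    def load_preprocess_meta(meta_path):
        # Hypothetical helper: reload the stats saved during training.
        meta = np.load(meta_path, allow_pickle=True)
        feature_columns = meta['feature_columns'].tolist()
        return feature_columns, meta['means'], meta['stds']

    def standardize(features, means, stds):
        # Apply the same per-column standardization the training dataset is assumed to use.
        return (features - means) / stds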
@@ -60,6 +62,35 @@ def train(dataset_root, epochs=3, batch_size=16, lr=1e-3, device=None, num_class
             model_num_classes = n_buckets
         else:
             model_num_classes = n_buckets if generate_labels else num_classes
+        # If labels were generated, save label metadata + assignments (if not huge)
+        if generate_labels:
+            try:
+                label_info = {
+                    "generated": True,
+                    "label_method": label_method,
+                    "n_buckets": n_buckets,
+                }
+                # save per-sample assignments if dataset exposes them
+                if hasattr(dataset, "labels"):
+                    try:
+                        # convert to list (JSON serializable)
+                        assignments = dataset.labels.cpu().numpy().tolist() if hasattr(dataset.labels, "cpu") else dataset.labels.tolist()
+                        # if too large, save as .npz instead
+                        if len(assignments) <= 100000:
+                            label_info["assignments"] = assignments
+                        else:
+                            import numpy as _np
+                            arr_path = os.path.join(output_dir, "label_assignments.npz")
+                            _np.savez_compressed(arr_path, assignments=_np.array(assignments))
+                            label_info["assignments_file"] = os.path.basename(arr_path)
+                    except Exception:
+                        pass
+                with open(os.path.join(output_dir, "label_info.json"), "w") as f:
+                    import json
+                    json.dump(label_info, f)
+                print(f"Saved label_info to {os.path.join(output_dir, 'label_info.json')}")
+            except Exception:
+                pass
         # parse hidden_dims if provided by caller (tuple or list)
         model = create_model(device=device, model_type='mlp', input_dim=input_dim, num_classes=model_num_classes, hidden_dims=hidden_dims)
     else:
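The label metadata round-trips in two shapes: small runs inline the per-sample assignments in label_info.json, while larger runs spill them to label_assignments.npz and record only the file name. A sketch of a hypothetical reader that handles both cases:

    import json
    import os
    import numpy as np

    def load_label_info(output_dir):
        # Hypothetical helper: read back the metadata written by train().
        with open(os.path.join(output_dir, 'label_info.json')) as f:
            info = json.load(f)
        if 'assignments' in info:
            assignments = info['assignments']  # <= 100000 samples: stored inline
        elif 'assignments_file' in info:
            arr = np.load(os.path.join(output_dir, info['assignments_file']))
            assignments = arr['assignments'].tolist()  # large runs: spilled to .npz
        else:
            assignments = None  # dataset did not expose per-sample labels
        return info, assignments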
@@ -112,11 +143,23 @@ def train(dataset_root, epochs=3, batch_size=16, lr=1e-3, device=None, num_class
         # save best
         if val_acc > best_val_acc:
-            out_path = os.path.join(os.getcwd(), 'model.pth')
+            out_path = os.path.join(output_dir, 'model.pth')
+            # include useful metadata so evaluator can reconstruct
+            meta = {
+                'model_state_dict': model.state_dict(),
+                'model_type': model_type,
+                'model_config': {
+                    'input_dim': input_dim if model_type == 'mlp' else None,
+                    'num_classes': model_num_classes,
+                    'hidden_dims': hidden_dims,
+                }
+            }
             if hasattr(dataset, 'class_to_idx'):
-                meta = {'model_state_dict': model.state_dict(), 'class_to_idx': dataset.class_to_idx}
-            else:
-                meta = {'model_state_dict': model.state_dict()}
+                meta['class_to_idx'] = dataset.class_to_idx
+            # also record paths to saved preprocess and label info (if present)
+            meta['preprocess_meta'] = os.path.basename(os.path.join(output_dir, 'preprocess_meta.npz'))
+            if os.path.exists(os.path.join(output_dir, 'label_info.json')):
+                meta['label_info'] = json.load(open(os.path.join(output_dir, 'label_info.json'), 'r'))
             torch.save(meta, out_path)
             best_val_acc = val_acc
             best_path = out_path
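Because the checkpoint now carries model_type and model_config, an evaluator can rebuild the network without guessing its shape. A sketch assuming the same create_model factory imported at the top of this file, with an illustrative checkpoint path:

    import torch
    from models import create_model

    ckpt = torch.load('runs/exp1/model.pth', map_location='cpu')  # illustrative path
    cfg = ckpt['model_config']
    model = create_model(device='cpu', model_type=ckpt['model_type'],
                         input_dim=cfg['input_dim'], num_classes=cfg['num_classes'],
                         hidden_dims=cfg['hidden_dims'])
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    class_to_idx = ckpt.get('class_to_idx')  # present only when the dataset exposed it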
@@ -127,6 +170,7 @@ def train(dataset_root, epochs=3, batch_size=16, lr=1e-3, device=None, num_class
 if __name__ == '__main__':
     import argparse
+    import json
     parser = argparse.ArgumentParser()
     parser.add_argument('data_root')
     parser.add_argument('--epochs', type=int, default=3)
@@ -144,6 +188,7 @@ if __name__ == '__main__':
     parser.add_argument('--seed', type=int, default=42, help='Random seed for experiments')
     parser.add_argument('--hidden-dims', type=str, default='', help='Comma-separated hidden dims for MLP, e.g. "256,128"')
     parser.add_argument('--weight-decay', type=float, default=0.0, help='Weight decay (L2) for optimizer')
+    parser.add_argument('--output-dir', default='.', help='Directory to save output files')
     args = parser.parse_args()
     data_root = args.data_root
     nrows = args.subset if args.subset > 0 else None
@@ -154,5 +199,14 @@ if __name__ == '__main__':
         hidden_dims = tuple(int(x) for x in args.hidden_dims.split(',') if x.strip())
     except Exception:
         hidden_dims = None
-    train(data_root, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, model_type=args.model_type, csv_label=args.csv_label, generate_labels=args.generate_labels, n_buckets=args.n_buckets, label_method=args.label_method, label_store=args.label_store, feature_engineer=args.feature_engineer, lat_lon_bins=args.lat_lon_bins, nrows=nrows, seed=args.seed, hidden_dims=hidden_dims, weight_decay=args.weight_decay)
+    if args.generate_labels:
+        os.makedirs(args.output_dir, exist_ok=True)
+        label_info = {
+            "generated": True,
+            "label_method": args.label_method,
+            "n_buckets": args.n_buckets,
+        }
+        with open(os.path.join(args.output_dir, "label_info.json"), "w") as f:
+            json.dump(label_info, f)
+    train(data_root, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, model_type=args.model_type, csv_label=args.csv_label, generate_labels=args.generate_labels, n_buckets=args.n_buckets, label_method=args.label_method, label_store=args.label_store, feature_engineer=args.feature_engineer, lat_lon_bins=args.lat_lon_bins, nrows=nrows, seed=args.seed, hidden_dims=hidden_dims, weight_decay=args.weight_decay, output_dir=args.output_dir)
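Putting the new flag together with label generation, a plausible invocation (flag spellings inferred from the args.* attributes above; paths and values illustrative):

    python train.py data/samples.csv --model-type mlp --generate-labels \
        --n-buckets 100 --hidden-dims "256,128" --output-dir runs/exp1

Note that when --generate-labels is set, __main__ writes a provisional label_info.json before calling train(), and train() then overwrites it with the richer version that can include per-sample assignments.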