# Example commands: training (train.py) and evaluation (evaluate_and_visualize.py)
# train the model:
|
|
python train.py data.csv --model-type mlp --generate-labels --label-method kmeans --n-buckets 50 --hidden-dims 512,256 --epochs 8 --batch-size 256 --feature-engineer --weight-decay 1e-5 --seed 42
|
|
|
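# alternative: wider network (1024,512), 12 epochs, explicit --lr 1e-3 with step decay (--lr-step-size 4, --lr-gamma 0.5)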
python train.py data.csv --model-type mlp --generate-labels --label-method kmeans --n-buckets 50 --hidden-dims 1024,512 --epochs 12 --batch-size 256 --lr 1e-3 --lr-step-size 4 --lr-gamma 0.5 --feature-engineer --weight-decay 1e-5 --seed 42
|
|
|
|
# train with outputs saved to output/
|
|
python train.py data.csv --model-type mlp --generate-labels --label-method kmeans --n-buckets 50 --hidden-dims 512,256 --epochs 8 --batch-size 256 --feature-engineer --weight-decay 1e-5 --seed 42 --output-dir output/
|
|
|
|
# evaluate and visualize (template: replace path/to/checkpoint.pt and original_label_column_name with real values):
|
|
python evaluate_and_visualize.py \
  --checkpoint path/to/checkpoint.pt \
  --data data.csv \
  --label-col original_label_column_name \
  --batch-size 256 \
  --sample-index 5 \
  --plot
|
|
|
|
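# evaluate the checkpoint saved in output/ (e.g. from the --output-dir output/ training run) and plot sample 5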
|
|
python evaluate_and_visualize.py --checkpoint output/model.pth --data data.csv --label-col label --plot --sample-index 5
|
|
|
|
# Note: if you trained with --generate-labels and train.py saved its label metadata,
# the evaluator prefers the generated labels stored in label_info.json in the checkpoint's directory.
|
|
|
|
# fetch road risk via the OpenWeather client (placeholder; lines below are commented out)
|
|
# from openweather_client import fetch_road_risk
|
|
# print(fetch_road_risk(37.7749, -122.4194))