VTHacks13/roadcast/fit_kmeans.py

"""Fit k-means centers on CSV numeric features (optionally PCA) and save centers to .npz
Usage: python fit_kmeans.py data.csv --n-buckets 10 --out kmeans_centers_final.npz --sample 50000 --pca 50
"""
import argparse
import numpy as np
import pandas as pd
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('csv')
    parser.add_argument('--n-buckets', type=int, default=10)
    parser.add_argument('--out', default='kmeans_centers_final.npz')
    parser.add_argument('--sample', type=int, default=50000, help='max rows to sample for fitting')
    parser.add_argument('--pca', type=int, default=0, help='Apply PCA to reduce dims before kmeans (0=none)')
    args = parser.parse_args()

    # Read the CSV and keep only numeric columns to limit memory usage.
    df = pd.read_csv(args.csv, low_memory=False)
    num_df = df.select_dtypes(include=['number']).fillna(0.0)
    data = num_df.values.astype(float)
    if data.shape[0] == 0 or data.shape[1] == 0:
        raise SystemExit('No numeric data found in CSV')

    # Subsample rows for fitting if requested.
    if args.sample and args.sample < data.shape[0]:
        rng = np.random.default_rng(42)
        idx = rng.choice(data.shape[0], size=args.sample, replace=False)
        sample_data = data[idx]
    else:
        sample_data = data

    # NOTE: the --pca flag is parsed above but no dimensionality reduction is applied here.
    # The k-means fit is implemented inline rather than calling data.generate_kmeans_labels.
    # Initialize centers by sampling k distinct rows at random.
    rng = np.random.default_rng(42)
    k = min(args.n_buckets, sample_data.shape[0])
    centers_idx = rng.choice(sample_data.shape[0], size=k, replace=False)
    centers = sample_data[centers_idx].astype(float)

    max_iters = 50
    for _ in range(max_iters):
        # Assign each row to its nearest center.
        dists = np.linalg.norm(sample_data[:, None, :] - centers[None, :, :], axis=2)
        labels = np.argmin(dists, axis=1)
        # Recompute each center as the mean of its assigned rows.
        new_centers = np.zeros_like(centers)
        counts = np.zeros((centers.shape[0],), dtype=int)
        for i, lab in enumerate(labels):
            new_centers[lab] += sample_data[i]
            counts[lab] += 1
        for kk in range(centers.shape[0]):
            if counts[kk] > 0:
                new_centers[kk] = new_centers[kk] / counts[kk]
            else:
                # Re-seed an empty cluster with a random row.
                new_centers[kk] = sample_data[rng.integers(0, sample_data.shape[0])]
        # Stop once the centers have effectively converged.
        shift = np.linalg.norm(new_centers - centers, axis=1).max()
        centers = new_centers
        if shift < 1e-4:
            break

    np.savez_compressed(args.out, centers=centers)
    print('Saved centers to', args.out)
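
For reference, a minimal sketch of how the saved .npz might be consumed downstream. The input filename new_data.csv is illustrative, and it is an assumption (not stated in the repo) that the new data's numeric columns match the ones used at fit time; only then are the distances to the saved centers meaningful.

# Sketch (assumed usage, not from the repo): load the saved centers and
# assign each row of a new CSV to its nearest center ("bucket").
import numpy as np
import pandas as pd

centers = np.load('kmeans_centers_final.npz')['centers']

df = pd.read_csv('new_data.csv', low_memory=False)
X = df.select_dtypes(include=['number']).fillna(0.0).values.astype(float)
# Distance from every row to every center, then pick the closest one.
dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
buckets = np.argmin(dists, axis=1)
print(buckets[:10])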