added the model
This commit is contained in:
65
roadcast/fit_kmeans.py
Normal file
65
roadcast/fit_kmeans.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Fit k-means centers on CSV numeric features (optionally PCA) and save centers to .npz
|
||||
|
||||
Usage: python fit_kmeans.py data.csv --n-buckets 10 --out kmeans_centers_final.npz --sample 50000 --pca 50
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from data import generate_kmeans_labels
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('csv')
|
||||
parser.add_argument('--n-buckets', type=int, default=10)
|
||||
parser.add_argument('--out', default='kmeans_centers_final.npz')
|
||||
parser.add_argument('--sample', type=int, default=50000, help='max rows to sample for fitting')
|
||||
parser.add_argument('--pca', type=int, default=0, help='Apply PCA to reduce dims before kmeans (0=none)')
|
||||
args = parser.parse_args()
|
||||
|
||||
# read numeric columns only to avoid huge memory usage
|
||||
df = pd.read_csv(args.csv, low_memory=False)
|
||||
num_df = df.select_dtypes(include=['number']).fillna(0.0)
|
||||
data = num_df.values.astype(float)
|
||||
if data.shape[0] == 0 or data.shape[1] == 0:
|
||||
raise SystemExit('No numeric data found in CSV')
|
||||
|
||||
# sample rows if requested
|
||||
if args.sample and args.sample < data.shape[0]:
|
||||
rng = np.random.default_rng(42)
|
||||
idx = rng.choice(data.shape[0], size=args.sample, replace=False)
|
||||
sample_data = data[idx]
|
||||
else:
|
||||
sample_data = data
|
||||
|
||||
# Use the kmeans implementation via generate_kmeans_labels for fitting centers.
|
||||
# We'll call the internal function by adapting it here: import numpy locally.
|
||||
import numpy as _np
|
||||
|
||||
# initialize centers by random sampling
|
||||
rng = _np.random.default_rng(42)
|
||||
k = min(args.n_buckets, sample_data.shape[0])
|
||||
centers_idx = rng.choice(sample_data.shape[0], size=k, replace=False)
|
||||
centers = sample_data[centers_idx].astype(float)
|
||||
|
||||
max_iters = 50
|
||||
for _ in range(max_iters):
|
||||
dists = np.linalg.norm(sample_data[:, None, :] - centers[None, :, :], axis=2)
|
||||
labels = np.argmin(dists, axis=1)
|
||||
new_centers = np.zeros_like(centers)
|
||||
counts = np.zeros((centers.shape[0],), dtype=int)
|
||||
for i, lab in enumerate(labels):
|
||||
new_centers[lab] += sample_data[i]
|
||||
counts[lab] += 1
|
||||
for kk in range(centers.shape[0]):
|
||||
if counts[kk] > 0:
|
||||
new_centers[kk] = new_centers[kk] / counts[kk]
|
||||
else:
|
||||
new_centers[kk] = sample_data[rng.integers(0, sample_data.shape[0])]
|
||||
shift = np.linalg.norm(new_centers - centers, axis=1).max()
|
||||
centers = new_centers
|
||||
if shift < 1e-4:
|
||||
break
|
||||
|
||||
np.savez_compressed(args.out, centers=centers)
|
||||
print('Saved centers to', args.out)
|
||||
Reference in New Issue
Block a user