66 lines
2.6 KiB
Python
66 lines
2.6 KiB
Python
"""Fit k-means centers on the numeric columns of a CSV file and save them to .npz.

Features can optionally be reduced with PCA before clustering.

Usage: python fit_kmeans.py data.csv --n-buckets 10 --out kmeans_centers_final.npz --sample 50000 --pca 50
"""
|
|
import argparse
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from data import generate_kmeans_labels
|
|
|
|
def load_numeric_matrix(csv_path):
    """Load the numeric columns of a CSV file as a 2-D float array.

    Non-numeric columns are dropped and missing values become 0.0. pandas has
    to parse the whole file to infer column dtypes, so peak memory is that of
    the full DataFrame, not just its numeric part.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        A ``(n_rows, n_numeric_columns)`` float array.

    Raises:
        SystemExit: If the CSV yields no rows or no numeric columns.
    """
    df = pd.read_csv(csv_path, low_memory=False)
    numeric = df.select_dtypes(include=['number']).fillna(0.0)
    data = numeric.values.astype(float)
    if data.shape[0] == 0 or data.shape[1] == 0:
        raise SystemExit('No numeric data found in CSV')
    return data


def sample_rows(data, max_rows, seed=42):
    """Return at most ``max_rows`` rows of ``data``, drawn without replacement.

    Args:
        data: 2-D array of rows.
        max_rows: Row budget. ``None``, zero, a negative value, or a value of
            at least ``len(data)`` means "use every row" (``data`` is
            returned unchanged).
        seed: Seed for the sampling generator, for reproducible subsets.

    Returns:
        ``data`` itself, or a new array holding the sampled rows.
    """
    n_rows = data.shape[0]
    if max_rows is None or max_rows <= 0 or max_rows >= n_rows:
        return data
    rng = np.random.default_rng(seed)
    idx = rng.choice(n_rows, size=max_rows, replace=False)
    return data[idx]


def pca_project(data, n_components):
    """Project ``data`` onto its leading principal components.

    Args:
        data: ``(n_rows, n_features)`` array.
        n_components: Requested number of components. It is clipped to the
            range ``[1, min(n_rows, n_features)]``.

    Returns:
        ``(projected, mean, components)``. ``components`` has orthonormal rows
        of shape ``(n_kept, n_features)``, and ``projected @ components + mean``
        maps projected rows back into the original feature space.
    """
    mean = data.mean(axis=0)
    centered = data - mean
    # Rows of vt are the principal axes, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    n_kept = max(1, min(n_components, vt.shape[0]))
    components = vt[:n_kept]
    return centered @ components.T, mean, components


def assign_labels(points, centers, max_block_elems=1 << 22):
    """Return the index of the nearest center (Euclidean) for every point.

    Distances are computed in row blocks. This caps the temporary
    ``(rows, k, d)`` difference array at about ``max_block_elems`` elements
    instead of materialising it for all points at once.

    Args:
        points: ``(n_points, d)`` array.
        centers: ``(k, d)`` array.
        max_block_elems: Element budget for one block's difference array.

    Returns:
        Integer array of length ``n_points`` with values in ``[0, k)``.
    """
    n_points, n_features = points.shape
    n_centers = centers.shape[0]
    rows_per_block = max(1, max_block_elems // max(1, n_centers * n_features))
    labels = np.empty(n_points, dtype=np.intp)
    for start in range(0, n_points, rows_per_block):
        stop = start + rows_per_block
        diff = points[start:stop, None, :] - centers[None, :, :]
        # Squared distance has the same argmin as distance, and skips the sqrt.
        labels[start:stop] = np.einsum('ijk,ijk->ij', diff, diff).argmin(axis=1)
    return labels


def fit_kmeans_centers(points, n_clusters, max_iters=50, tol=1e-4, seed=42):
    """Fit k-means (Lloyd's algorithm) and return the cluster centers.

    Centers start at ``k = min(n_clusters, n_points)`` distinct random rows.
    A cluster that ends up empty is re-seeded with a random row. Iteration
    stops after ``max_iters`` rounds, or as soon as no center moves farther
    than ``tol``.

    Args:
        points: ``(n_points, d)`` array with at least one row.
        n_clusters: Requested number of clusters (must be >= 1).
        max_iters: Maximum number of assignment/update rounds.
        tol: Convergence threshold on the largest center shift.
        seed: Seed for initialisation and empty-cluster re-seeding.

    Returns:
        ``(k, d)`` float array of centers.

    Raises:
        ValueError: If ``n_clusters < 1`` or ``points`` has no rows.
    """
    if n_clusters < 1:
        raise ValueError('n_clusters must be at least 1')
    n_points = points.shape[0]
    if n_points == 0:
        raise ValueError('cannot fit k-means on zero points')

    rng = np.random.default_rng(seed)
    k = min(n_clusters, n_points)
    centers = points[rng.choice(n_points, size=k, replace=False)].astype(float)

    for _ in range(max_iters):
        labels = assign_labels(points, centers)
        counts = np.bincount(labels, minlength=k)
        new_centers = np.zeros_like(centers)
        # Unbuffered scatter-add: sums rows per label in a single numpy call.
        np.add.at(new_centers, labels, points)
        nonempty = counts > 0
        new_centers[nonempty] /= counts[nonempty, None]
        # Re-seed empty clusters in index order, one RNG draw each.
        for kk in np.flatnonzero(~nonempty):
            new_centers[kk] = points[rng.integers(0, n_points)]
        shift = np.linalg.norm(new_centers - centers, axis=1).max()
        centers = new_centers
        if shift < tol:
            break
    return centers


def main(argv=None):
    """Parse CLI arguments, fit k-means centers and save them to an .npz file.

    Args:
        argv: Argument list to parse. ``None`` means use ``sys.argv[1:]``.

    The saved array is stored under the key ``centers`` with shape
    ``(k, n_numeric_columns)``. This holds even when ``--pca`` is used: the
    centers found in PCA space are mapped back into the original feature
    space. A point's distance to any such center equals its distance in PCA
    space plus a per-point residual term, so nearest-center assignments are
    unchanged. Consumers therefore need no projection step.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('csv')
    parser.add_argument('--n-buckets', type=int, default=10)
    parser.add_argument('--out', default='kmeans_centers_final.npz')
    parser.add_argument('--sample', type=int, default=50000, help='max rows to sample for fitting')
    parser.add_argument('--pca', type=int, default=0, help='Apply PCA to reduce dims before kmeans (0=none)')
    args = parser.parse_args(argv)
    if args.n_buckets < 1:
        parser.error('--n-buckets must be at least 1')

    data = load_numeric_matrix(args.csv)
    sample_data = sample_rows(data, args.sample)
    del data  # drop our reference to the full matrix before fitting

    if args.pca > 0:
        reduced, mean, components = pca_project(sample_data, args.pca)
        centers = fit_kmeans_centers(reduced, args.n_buckets)
        centers = centers @ components + mean
    else:
        centers = fit_kmeans_centers(sample_data, args.n_buckets)

    np.savez_compressed(args.out, centers=centers)
    print('Saved centers to', args.out)


if __name__ == '__main__':
    main()
|