#!/usr/bin/env python3
from sklearn.cluster import KMeans
import numpy as np
import glob
import os
import pandas as pd
import argparse

# Absolute directory containing this script; the data folders sit beside it.
DIR: str = os.path.dirname(os.path.realpath(__file__))
# Output folder for the per-class "<name>_kmeans.csv" cluster assignments.
OUT_DIR: str = f'{DIR}/clustering'
# Input folder holding the per-class "<name>.csv" feature vectors.
IN_DIR: str = f'{DIR}/feature_vectors'

# Seed handed to KMeans so repeated runs give the same clustering.
RAND_SEED: int = 0


def cluster_kmeans(path: str, n_clusters: int,
                   save_to_disk: bool = True) -> tuple[np.ndarray, np.ndarray]:
    """Run k-means over the feature vectors stored in a CSV file.

    The first CSV column is treated as the row label (e.g. the method name);
    every remaining column is used as a feature.

    Args:
        path: Path to the input CSV. Its base name without the extension is
            used to name the output file.
        n_clusters: Number of clusters passed to ``KMeans``.
        save_to_disk: When true, write the row-label -> cluster assignment
            to ``<OUT_DIR>/<name>_kmeans.csv``, creating ``OUT_DIR`` if it
            does not exist yet.

    Returns:
        ``(X, Y)``: the feature matrix that was clustered and the cluster
        index assigned to each of its rows.
    """
    # splitext copes with names that have no extension; the previous
    # rfind('.') slice returned -1 there and chopped off the last character.
    clazz_name = os.path.splitext(os.path.basename(path))[0]

    df = pd.read_csv(path)
    row_labels = df.iloc[:, 0].to_numpy()  # first column: row labels
    X = df.iloc[:, 1:].to_numpy()          # remaining columns: features

    kmeans = KMeans(n_clusters=n_clusters,
                    random_state=RAND_SEED, n_init='auto').fit(X)
    Y = kmeans.labels_  # cluster index assigned to each row of X

    if save_to_disk:
        # Pair each row label with its cluster; the labels become the index
        # so they are written as the first CSV column.
        assigned = pd.DataFrame({'cluster': Y}, index=row_labels)
        # to_csv does not create missing directories.
        os.makedirs(OUT_DIR, exist_ok=True)
        assigned.to_csv(os.path.join(OUT_DIR, clazz_name + '_kmeans.csv'))

    return X, Y


def main():
    """Command-line entry point: k-means-cluster one class's feature vectors.

    Reads ``<IN_DIR>/<class_name>.csv``, deletes any existing
    ``<OUT_DIR>/<class_name>_kmeans.csv`` and then runs ``cluster_kmeans``,
    which writes a fresh assignment file.
    """
    parser = argparse.ArgumentParser(
        description='Compute k-means clustering')
    parser.add_argument('class_name', type=str, help='name of the god class')
    parser.add_argument('n_clusters', type=int, help='number of clusters')

    args = parser.parse_args()
    # Reject impossible cluster counts with a usage message instead of a
    # traceback from deep inside KMeans.
    if args.n_clusters < 1:
        parser.error('n_clusters must be a positive integer')

    path = os.path.join(IN_DIR, args.class_name + '.csv')

    # Delete the previous result first, so a failed run leaves no stale file.
    # On a first run there is nothing to delete, which is not an error
    # (os.remove raises FileNotFoundError in that case).
    try:
        os.remove(os.path.join(OUT_DIR, args.class_name + '_kmeans.csv'))
    except FileNotFoundError:
        pass

    cluster_kmeans(path, args.n_clusters)


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()