#!/usr/bin/env python3
"""Compute silhouette scores for k-means and hierarchical clusterings.

For every feature-vector CSV in ``IN_DIR`` this script prints the silhouette
score of the label files stored in ``OUT_DIR`` as
``<class>_kmeans.csv`` / ``<class>_hierarchical.csv``.

With ``--validate`` it first sweeps the number of clusters ``k`` over
``2 .. K_MAX - 1`` for both algorithms. It writes the per-``k`` scores to
``<class>_silhouette.csv`` and prints the ``k`` with the highest score.
``--autorun`` additionally re-runs both clusterings at their optimal ``k``
(presumably refreshing the label files on disk — verify in the clustering
modules).
"""
import argparse
import glob
import os

import numpy as np
import pandas as pd
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score

from k_means import cluster_kmeans
from hierarchical import cluster_hierarchical

DIR: str = os.path.dirname(os.path.realpath(__file__))
OUT_DIR: str = os.path.join(DIR, 'clustering')
IN_DIR: str = os.path.join(DIR, 'feature_vectors')
# Exclusive upper bound of the k sweep in validate(): k runs over 2 .. K_MAX - 1.
K_MAX: int = 65


def clean_output() -> None:
    """Delete every ``*_silhouette.csv`` file previously written to OUT_DIR."""
    for f in glob.glob(os.path.join(OUT_DIR, '*_silhouette.csv')):
        os.remove(f)


def validate(path: str, clazz_name: str, autorun: bool) -> None:
    """Sweep k for both algorithms and report the k with the best silhouette.

    The per-k scores are written to ``OUT_DIR/<clazz_name>_silhouette.csv``
    (index = k, columns ``k_means`` and ``hierarchical``).

    Args:
        path: Feature-vector CSV passed through to the clustering helpers.
        clazz_name: Base name used for the output CSV.
        autorun: If true, re-run both clusterings at their optimal k using
            the helpers' default ``save_to_disk`` behaviour.
    """
    ks = range(2, K_MAX)
    scores = {'k_means': [], 'hierarchical': []}
    for n in ks:
        X_h, Y_h = cluster_hierarchical(path, n, save_to_disk=False)
        X_k, Y_k = cluster_kmeans(path, n, save_to_disk=False)
        # Each column must be scored with its own algorithm's data/labels,
        # otherwise the reported optimum belongs to the other algorithm.
        scores['k_means'].append(silhouette_score(X_k, Y_k))
        scores['hierarchical'].append(silhouette_score(X_h, Y_h))
    df = pd.DataFrame(scores, index=list(ks), dtype=float)

    # idxmax on a Series returns the index label (k) directly; indexing the
    # result of df[['col']].idxmax() with [0] relies on a deprecated
    # positional fallback in pandas.
    k_kmeans = int(df['k_means'].idxmax())
    k_hierarchical = int(df['hierarchical'].idxmax())

    print("K_means optimal value: " + str(k_kmeans))
    print("Hierarchical optimal value: " + str(k_hierarchical))

    # to_csv does not create missing parent directories.
    os.makedirs(OUT_DIR, exist_ok=True)
    df.to_csv(os.path.join(OUT_DIR, clazz_name + '_silhouette.csv'))

    if autorun:
        cluster_hierarchical(path, k_hierarchical)
        cluster_kmeans(path, k_kmeans)


def compute_silhouette(path: str, clazz_name: str, suffix: str) -> None:
    """Print the silhouette score of a clustering saved in OUT_DIR.

    Args:
        path: Feature-vector CSV. Its first column is dropped (presumably an
            identifier column — confirm) and the remaining columns form X.
        clazz_name: Base name of the saved label file.
        suffix: Algorithm suffix (``'kmeans'`` or ``'hierarchical'`` in
            main). Labels are read from the second column of
            ``OUT_DIR/<clazz_name>_<suffix>.csv``.
    """
    df_y = pd.read_csv(os.path.join(OUT_DIR, clazz_name + '_' + suffix + '.csv'))
    Y = df_y.iloc[:, 1].to_numpy()
    df = pd.read_csv(path)
    X = df.drop(columns=df.columns[0]).to_numpy()
    print("Silhouette for " + suffix + ": " + str(silhouette_score(X, Y)))


def main() -> None:
    """Parse CLI flags and score every feature-vector CSV in IN_DIR."""
    parser = argparse.ArgumentParser(description='Compute silhouette metric.')
    parser.add_argument('--validate', action='store_true',
                        help='compute optimal k for each algorithm')
    parser.add_argument('--autorun', action='store_true',
                        help='if validating, computes CSV for optimal clustering automatically')
    args = parser.parse_args()

    if args.validate:
        clean_output()

    # Sorted so the per-class output order is deterministic across runs.
    for f in sorted(glob.glob(os.path.join(IN_DIR, '*.csv'))):
        clazz_name = os.path.splitext(os.path.basename(f))[0]
        print(clazz_name)
        if args.validate:
            validate(f, clazz_name, args.autorun)
        compute_silhouette(f, clazz_name, 'kmeans')
        compute_silhouette(f, clazz_name, 'hierarchical')
        print()


if __name__ == '__main__':
    main()