# ima-preparation/god-2022/silhouette.py

# 90 lines
# 2.5 KiB
# Python
# Executable File

#!/usr/bin/env python3
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score
import numpy as np
import glob
import os
import pandas as pd
import argparse
from k_means import cluster_kmeans
from hierarchical import cluster_hierarchical
# Directory containing this script; the input and output folders are resolved relative to it.
DIR: str = os.path.dirname(os.path.realpath(__file__))
# Folder where the '<class>_silhouette.csv' tables are written and the clustering CSVs are read.
OUT_DIR: str = DIR + '/clustering'
# Folder scanned for the input '*.csv' files that are scored.
IN_DIR: str = DIR + '/feature_vectors'
# Exclusive upper bound of the cluster counts tried by validate(), i.e. range(2, K_MAX).
K_MAX: int = 65
def clean_output():
    """Delete every ``*_silhouette.csv`` file located directly in ``OUT_DIR``."""
    pattern = OUT_DIR + '/*_silhouette.csv'
    for stale_csv in glob.glob(pattern):
        os.remove(stale_csv)
def validate(path: str, clazz_name: str, autorun: bool):
    """Sweep the cluster count and report the silhouette-optimal k per algorithm.

    For every ``n`` in ``range(2, K_MAX)`` the data at *path* is clustered with
    both algorithms (with ``save_to_disk=False``) and the silhouette score of
    each clustering is stored in a table indexed by ``n``. The table is written
    to ``OUT_DIR/<clazz_name>_silhouette.csv`` and the best ``n`` of each
    algorithm is printed.

    Args:
        path: Path handed to ``cluster_kmeans`` / ``cluster_hierarchical``.
        clazz_name: Prefix of the output CSV file name.
        autorun: When true, both algorithms are run once more with their own
            optimal ``n`` and their default ``save_to_disk`` setting
            (presumably persisting the clustering; confirm in the helpers).

    Raises:
        ValueError: Propagated from pandas when no score was recorded
            (``K_MAX <= 2``), or from ``silhouette_score`` when ``n`` is not a
            valid number of labels for the data set.
    """
    df = pd.DataFrame(columns=['k_means', 'hierarchical'], dtype=float)
    for n in range(2, K_MAX):
        X_h, Y_h = cluster_hierarchical(path, n, save_to_disk=False)
        X_k, Y_k = cluster_kmeans(path, n, save_to_disk=False)
        # Each column must hold the score of its own algorithm; the two were
        # previously swapped, which mislabeled the CSV and the optimal k values.
        df.loc[n, 'k_means'] = silhouette_score(X_k, Y_k)
        df.loc[n, 'hierarchical'] = silhouette_score(X_h, Y_h)

    # Series.idxmax() returns the index label (the cluster count) directly,
    # avoiding the deprecated positional ``[0]`` lookup on a labelled Series.
    k_kmeans = int(df['k_means'].idxmax())
    k_hierarchical = int(df['hierarchical'].idxmax())
    print("K_means optimal value: " + str(k_kmeans))
    print("Hierarchical optimal value: " + str(k_hierarchical))
    df.to_csv(OUT_DIR + '/' + clazz_name + '_silhouette.csv')

    if autorun:
        cluster_hierarchical(path, k_hierarchical)
        cluster_kmeans(path, k_kmeans)
def compute_silhouette(path: str, clazz_name: str, suffix: str):
    """Print the silhouette score of a clustering that was saved earlier.

    Labels are taken from the second column of
    ``OUT_DIR/<clazz_name>_<suffix>.csv``; the samples are every column of the
    CSV at *path* except the first one.

    Args:
        path: CSV file holding the samples.
        clazz_name: Prefix of the stored clustering file name.
        suffix: Algorithm tag in the clustering file name; also echoed in the
            printed message.
    """
    labels_file = OUT_DIR + '/' + clazz_name + '_' + suffix + '.csv'
    labels = pd.read_csv(labels_file).iloc[:, 1].to_numpy()

    samples_frame = pd.read_csv(path)
    first_column = samples_frame.columns[0]
    samples = samples_frame.drop(columns=first_column).to_numpy()

    score = silhouette_score(samples, labels)
    print(f"Silhouette for {suffix}: {score!s}")
def main():
    """Command-line entry point.

    Flags:
        --validate: clear old silhouette tables, then run validate() on each
            input file before scoring it.
        --autorun: forwarded to validate(); has no effect without --validate.

    For each ``*.csv`` in ``IN_DIR`` the file stem is printed, followed by the
    silhouette scores of the stored 'kmeans' and 'hierarchical' clusterings.
    """
    parser = argparse.ArgumentParser(description='Compute silhouette metric.')
    switches = (
        ('--validate', 'compute optimal k for each algorithm'),
        ('--autorun',
         'if validating, computes CSV for optimal clustering automatically'),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action='store_true', help=help_text)
    options = parser.parse_args()

    if options.validate:
        clean_output()

    for csv_path in glob.glob(IN_DIR + '/*.csv'):
        # Strip the final extension: everything before the last '.'.
        clazz_name = os.path.basename(csv_path).rpartition('.')[0]
        print(clazz_name)
        if options.validate:
            validate(csv_path, clazz_name, options.autorun)
        for algorithm in ('kmeans', 'hierarchical'):
            compute_silhouette(csv_path, clazz_name, algorithm)
        print()
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()