90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
|
#!/usr/bin/env python3
|
||
|
from sklearn.cluster import AgglomerativeClustering
|
||
|
from sklearn.metrics import silhouette_score
|
||
|
import numpy as np
|
||
|
import glob
|
||
|
import os
|
||
|
import pandas as pd
|
||
|
import argparse
|
||
|
from k_means import cluster_kmeans
|
||
|
from hierarchical import cluster_hierarchical
|
||
|
|
||
|
# Every path is anchored at the directory that holds this script (symlinks
# resolved), so the tool behaves the same regardless of the working directory.
DIR: str = os.path.dirname(os.path.realpath(__file__))
# Feature-vector CSVs that main() iterates over ('*.csv').
IN_DIR: str = f'{DIR}/feature_vectors'
# Clustering label CSVs are read from here; silhouette reports are written here.
OUT_DIR: str = f'{DIR}/clustering'

# Exclusive upper bound for the cluster counts swept by validate():
# n runs over range(2, K_MAX).
K_MAX: int = 65
|
||
|
|
||
|
|
||
|
def clean_output():
    """Delete every ``*_silhouette.csv`` report previously written to OUT_DIR.

    Other files in OUT_DIR (e.g. the clustering label CSVs) are left alone.
    """
    stale_pattern = OUT_DIR + '/*_silhouette.csv'
    for stale_report in glob.glob(stale_pattern):
        os.remove(stale_report)
|
||
|
|
||
|
|
||
|
def validate(path: str, clazz_name: str, autorun: bool):
    """Find the cluster count with the best silhouette score for each algorithm.

    For every ``n`` in ``range(2, K_MAX)`` the feature vectors at *path* are
    clustered with both ``cluster_kmeans`` and ``cluster_hierarchical``
    (``save_to_disk=False``), and the silhouette score of each labelling is
    recorded. The per-``n`` score table (columns ``k_means`` and
    ``hierarchical``, indexed by ``n``) is written to
    ``OUT_DIR/<clazz_name>_silhouette.csv``. The ``n`` with the highest score
    for each algorithm is printed.

    Args:
        path: Feature-vector CSV handed to both clustering helpers.
        clazz_name: Base name used for the output CSV.
        autorun: If true, re-run both helpers with their own optimal ``n``,
            using the helpers' default ``save_to_disk`` setting. Per the
            ``--autorun`` help text this is meant to write the clustering
            CSVs; confirm in the helpers.
    """
    # Scores keyed by n. Each algorithm must be scored on its OWN
    # (features, labels) pair; the previous version stored the hierarchical
    # score under 'k_means' and vice versa, which swapped the optima.
    kmeans_scores = {}
    hierarchical_scores = {}
    for n in range(2, K_MAX):
        X_h, Y_h = cluster_hierarchical(path, n, save_to_disk=False)
        X_k, Y_k = cluster_kmeans(path, n, save_to_disk=False)

        kmeans_scores[n] = silhouette_score(X_k, Y_k)
        hierarchical_scores[n] = silhouette_score(X_h, Y_h)

    df = pd.DataFrame({'k_means': kmeans_scores,
                       'hierarchical': hierarchical_scores}, dtype=float)

    # Select each column as a Series so idxmax() returns the n label directly.
    # This avoids the former positional ``[0]`` lookup on a string-indexed
    # Series, which is deprecated in pandas 2.x.
    k_kmeans = df['k_means'].idxmax()
    k_hierarchical = df['hierarchical'].idxmax()

    print("K_means optimal value: " + str(k_kmeans))
    print("Hierarchical optimal value: " + str(k_hierarchical))

    df.to_csv(OUT_DIR + '/' + clazz_name + '_silhouette.csv')

    if autorun:
        cluster_hierarchical(path, k_hierarchical)
        cluster_kmeans(path, k_kmeans)
|
||
|
|
||
|
|
||
|
|
||
|
def compute_silhouette(path: str, clazz_name: str, suffix: str):
    """Print the silhouette score of a clustering previously saved to disk.

    Cluster labels are taken from the second column of
    ``OUT_DIR/<clazz_name>_<suffix>.csv``. The feature matrix is every column
    of the CSV at *path* except the first. Rows of the two files are paired
    positionally, so both files are assumed to list samples in the same order
    (TODO confirm against the writer of the label CSV).
    """
    labels_csv = OUT_DIR + '/' + clazz_name + '_' + suffix + '.csv'
    labels = pd.read_csv(labels_csv).iloc[:, 1].values

    features_df = pd.read_csv(path)
    id_column = features_df.columns[0]
    features = features_df.drop(id_column, axis=1).to_numpy()

    score = silhouette_score(features, labels)
    print("Silhouette for " + suffix + ": " + str(score))
|
||
|
|
||
|
|
||
|
def main():
    """Command-line entry point: report silhouette scores for each feature CSV.

    For every ``*.csv`` in IN_DIR, the saved ``kmeans`` and ``hierarchical``
    clusterings are scored with ``compute_silhouette``.

    With ``--validate``, existing ``*_silhouette.csv`` reports are deleted
    first, and ``validate`` sweeps the cluster count for each file before
    scoring. With ``--autorun``, that sweep also re-runs the clustering
    helpers with the optimal k.
    """
    parser = argparse.ArgumentParser(description='Compute silhouette metric.')
    parser.add_argument('--validate', action='store_true',
                        help='compute optimal k for each algorithm')
    parser.add_argument('--autorun', action='store_true',
                        help='if validating, computes CSV for optimal clustering automatically')
    args = parser.parse_args()

    if args.validate:
        clean_output()

    for feature_csv in glob.glob(IN_DIR + '/*.csv'):
        # Class name is the file name without its final extension.
        base_name = os.path.basename(feature_csv)
        clazz_name = base_name[:base_name.rfind('.')]
        print(clazz_name)

        if args.validate:
            validate(feature_csv, clazz_name, args.autorun)

        for algorithm_suffix in ('kmeans', 'hierarchical'):
            compute_silhouette(feature_csv, clazz_name, algorithm_suffix)

        print()
|
||
|
|
||
|
|
||
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|