2023-03-22 13:28:17 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
from sklearn.cluster import AgglomerativeClustering
|
|
|
|
from sklearn.metrics import silhouette_score
|
|
|
|
import numpy as np
|
|
|
|
import glob
|
|
|
|
import os
|
|
|
|
import pandas as pd
|
|
|
|
import argparse
|
|
|
|
from k_means import cluster_kmeans
|
|
|
|
from hierarchical import cluster_hierarchical
|
2023-04-19 14:53:00 +00:00
|
|
|
from collections import Counter
|
|
|
|
import seaborn as sns
|
|
|
|
import matplotlib.pyplot as plt
|
2023-03-22 13:28:17 +00:00
|
|
|
|
|
|
|
# Directory containing this script; input and output folders are built from it.
DIR: str = os.path.dirname(os.path.realpath(__file__))

# Destination for silhouette tables, cluster-size stats and plots.
OUT_DIR: str = f"{DIR}/clustering"

# Source folder holding one feature-vector CSV per class.
IN_DIR: str = f"{DIR}/feature_vectors"

# Exclusive upper bound on the number of clusters swept by validate().
K_MAX: int = 65
|
|
|
|
|
|
|
|
|
|
|
|
def clean_output():
    """Remove silhouette CSVs and PNG plots left in OUT_DIR by a previous run.

    Only files matching ``*_silhouette.csv`` and ``*.png`` are deleted; every
    other file in OUT_DIR is left alone.
    """
    for pattern in ('/*_silhouette.csv', '/*.png'):
        for stale in glob.glob(OUT_DIR + pattern):
            os.remove(stale)
|
2023-03-22 13:28:17 +00:00
|
|
|
|
|
|
|
|
2023-04-19 14:53:00 +00:00
|
|
|
def _sweep_k(path: str, limit: int):
    """Cluster *path* with both algorithms for every k in ``range(2, limit)``.

    Returns ``(df, df_stats)``:
      * ``df`` -- silhouette score indexed by k, with columns ``'k_means'``
        and ``'hierarchical'``;
      * ``df_stats`` -- one row per (algorithm, k) run holding the min / mean /
        max number of points per cluster.
    """
    scores = {'k_means': {}, 'hierarchical': {}}
    stat_rows = []
    for n in range(2, limit):
        # Hierarchical first, then k-means, for each k (this fixes the row order
        # of the stats table).
        for algo, cluster in (('hierarchical', cluster_hierarchical),
                              ('k_means', cluster_kmeans)):
            X, Y = cluster(path, n, save_to_disk=False)
            sizes = list(Counter(Y).values())  # number of points in each cluster
            stat_rows.append([algo, n, np.min(sizes), np.mean(sizes), np.max(sizes)])
            scores[algo][n] = silhouette_score(X, Y)
    df = pd.DataFrame(scores, columns=['k_means', 'hierarchical'], dtype=float)
    df_stats = pd.DataFrame(stat_rows, columns=['algorithm', 'k', 'min', 'mean', 'max'])
    return df, df_stats


def _plot_stats(df, df_stats, limit: int, png_path: str):
    """Save a 3-panel figure to *png_path* and release it.

    Top row: min/mean/max cluster size per k for k-means (left) and
    hierarchical (right). Bottom row: silhouette per k for both algorithms.
    """
    sns.set_theme(palette="hls")

    fig = plt.figure(figsize=(14, 12))
    gs = fig.add_gridspec(2, 2)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, :])

    cols = ['k', 'min', 'mean', 'max']
    df_k = df_stats.loc[df_stats.algorithm == 'k_means', cols].set_index('k', drop=True)
    df_h = df_stats.loc[df_stats.algorithm == 'hierarchical', cols].set_index('k', drop=True)

    sns.lineplot(data=df_k, palette="tab10", ax=ax1)
    sns.lineplot(data=df_h, palette="tab10", ax=ax2)
    sns.lineplot(data=df, palette="tab10", ax=ax3)

    # Add a legend and informative axis label
    ax1.set(ylabel="# of elements", ylim=[0, 130], xlabel="# of clusters", xlim=[2, limit])
    ax1.set_title("K-Means cluster sizes")
    ax2.set(ylabel="# of elements", ylim=[0, 130], xlabel="# of clusters", xlim=[2, limit])
    ax2.set_title("Hierarchical cluster sizes")
    ax3.set(ylabel="Silhouette", ylim=[0, 1], xlabel="# of clusters", xlim=[2, limit])
    ax3.set_title("Silhouette metrics per # of clusters")

    sns.despine(left=True, bottom=True)
    fig.savefig(png_path)
    # close() rather than clf(): clf() keeps the figure registered with pyplot,
    # so one figure per class would accumulate over a run.
    plt.close(fig)


def validate(path: str, clazz_name: str, autorun: bool, df_table):
    """Sweep k for both algorithms on one class and record the best k.

    For each k in ``range(2, min(K_MAX, #distinct rows))`` the feature file is
    clustered with ``cluster_hierarchical`` and ``cluster_kmeans`` (without
    saving), and the silhouette score and cluster-size stats are collected.

    Writes ``OUT_DIR/<clazz_name>_silhouette.csv``, ``<clazz_name>_stats.csv``
    and ``<clazz_name>_stats.png``.

    Args:
        path: feature-vector CSV (first column is the row index).
        clazz_name: class label used for output file names and as the row
            label in *df_table*.
        autorun: if True, re-run both clustering functions with their chosen
            k using their default arguments (presumably saving to disk --
            see their definitions).
        df_table: DataFrame with columns KMeans K / KMeans silhouette /
            Hierarchical K / Hierarchical silhouette; one row is written in
            place (silhouette cells set to 0 as placeholders).

    Raises:
        ValueError: if the file has fewer than 3 distinct feature vectors, so
            no k can be evaluated.
    """
    # We bound the number of clusters by the number of distinct points in our
    # dataset: count the distinct feature vectors and bound to the minimum of
    # K_MAX and this number (silhouette needs fewer labels than samples).
    max_distinct = len(pd.read_csv(path, index_col=0).drop_duplicates())
    limit = min(K_MAX, max_distinct)
    if limit <= 2:
        raise ValueError(
            f"{path}: need at least 3 distinct feature vectors to choose k, "
            f"found {max_distinct}"
        )

    df, df_stats = _sweep_k(path, limit)

    # Best k is the index label with the highest silhouette. Series.idxmax()
    # returns that label directly (positional ``[0]`` on a label-indexed
    # Series is deprecated in pandas 2).
    k_kmeans = df['k_means'].idxmax()
    k_hierarchical = df['hierarchical'].idxmax()

    # Silhouette cells are placeholders; the caller overwrites them later.
    df_table.loc[clazz_name] = [k_kmeans, 0, k_hierarchical, 0]

    os.makedirs(OUT_DIR, exist_ok=True)
    df.to_csv(OUT_DIR + '/' + clazz_name + '_silhouette.csv')
    df_stats.to_csv(OUT_DIR + '/' + clazz_name + '_stats.csv')

    if autorun:
        cluster_hierarchical(path, k_hierarchical)
        cluster_kmeans(path, k_kmeans)

    # Plot stats
    _plot_stats(df, df_stats, limit, OUT_DIR + '/' + clazz_name + '_stats.png')
|
|
|
|
|
|
|
|
|
|
|
|
def compute_silhouette(path: str, clazz_name: str, suffix: str) -> float:
    """Score a clustering previously stored in OUT_DIR.

    Cluster labels come from the second column of
    ``OUT_DIR/<clazz_name>_<suffix>.csv``. Features come from every column of
    *path* except the first. Rows of the two files are paired by position
    (assumes both list samples in the same order -- TODO confirm).

    Prints the score and returns it rounded to 4 decimals.
    """
    labels_csv = f"{OUT_DIR}/{clazz_name}_{suffix}.csv"
    labels = pd.read_csv(labels_csv).iloc[:, 1].values

    # Column 0 of the feature file is the row identifier, not a feature.
    features = pd.read_csv(path).iloc[:, 1:].to_numpy()

    score = round(silhouette_score(features, labels), 4)
    print(f"Silhouette for {suffix}: {str(score)}")
    return score
|
2023-03-22 13:28:17 +00:00
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: report silhouettes, optionally choosing k first.

    For every ``*.csv`` in IN_DIR (one feature-vector file per class), print
    the silhouette score of the k-means and hierarchical clusterings stored in
    OUT_DIR.

    With ``--validate``, old silhouette/PNG outputs are removed first. Then
    ``validate`` sweeps k for each class, and a Markdown summary table is
    printed once at the end. With ``--autorun`` as well, ``validate``
    regenerates the clusterings with the chosen k before they are scored.
    """
    parser = argparse.ArgumentParser(description='Compute silhouette metric.')
    parser.add_argument('--validate', action='store_true',
                        help='compute optimal k for each algorithm')
    parser.add_argument('--autorun', action='store_true',
                        help='if validating, computes CSV for optimal clustering automatically')
    args = parser.parse_args()

    df_table = None
    if args.validate:
        clean_output()
        df_table = pd.DataFrame(columns=['KMeans K', 'KMeans silhouette',
                                         'Hierarchical K', 'Hierarchical silhouette'])

    # Sorted so processing order (and the summary table) is deterministic;
    # glob order is filesystem-dependent.
    for f in sorted(glob.glob(IN_DIR + '/*.csv')):
        clazz_name = os.path.splitext(os.path.basename(f))[0]

        if args.validate:
            validate(f, clazz_name, args.autorun, df_table)

        sk = compute_silhouette(f, clazz_name, 'kmeans')
        sh = compute_silhouette(f, clazz_name, 'hierarchical')

        if args.validate:
            df_table.loc[clazz_name, 'KMeans silhouette'] = sk
            df_table.loc[clazz_name, 'Hierarchical silhouette'] = sh

    # The table only exists in validate mode; print it once, after all classes.
    if args.validate:
        df_table.index.name = 'Class Name'
        print(df_table.to_markdown())
|
2023-03-22 13:28:17 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|