This repository has been archived on 2023-06-18. You can view files and clone it, but cannot push or open issues or pull requests.
ima01/silhouette.py

148 lines
5.1 KiB
Python
Executable File

#!/usr/bin/env python3
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score
import numpy as np
import glob
import os
import pandas as pd
import argparse
from k_means import cluster_kmeans
from hierarchical import cluster_hierarchical
from collections import Counter
import seaborn as sns
import matplotlib.pyplot as plt
# Directory that contains this script; the input/output folders sit beside it.
DIR: str = os.path.dirname(os.path.realpath(__file__))
# Output folder: silhouette tables, cluster-size stats and plots are written here.
OUT_DIR: str = f"{DIR}/clustering"
# Input folder: one feature-vector CSV per class; the file stem is the class name.
IN_DIR: str = f"{DIR}/feature_vectors"
# Exclusive upper bound on the number of clusters tried by validate().
K_MAX: int = 65
def clean_output():
    """Delete stale outputs from OUT_DIR before a new validation run.

    Removes every ``*_silhouette.csv`` file first, then every ``*.png`` file.
    """
    for pattern in ('/*_silhouette.csv', '/*.png'):
        # Each pattern is globbed only after the previous pattern's files are gone.
        for stale_file in glob.glob(OUT_DIR + pattern):
            os.remove(stale_file)
def _cluster_size_stats(labels):
    """Return (min, mean, max) of the number of points assigned to each label."""
    sizes = list(Counter(labels).values())
    return np.min(sizes), np.mean(sizes), np.max(sizes)


def _sweep_k(path: str, limit: int):
    """Cluster *path* with both algorithms for every k in [2, limit).

    Returns ``(scores, stats)``:
      * ``scores``: DataFrame indexed by k with float columns ``'k_means'`` and
        ``'hierarchical'`` holding the silhouette score of each clustering;
      * ``stats``: one row per (algorithm, k), hierarchical row first, with
        columns ``algorithm, k, min, mean, max`` describing cluster sizes.
    """
    ks = list(range(2, limit))
    kmeans_scores, hierarchical_scores = [], []
    stat_rows = []
    for k in ks:
        # Hierarchical first, then k-means (same call order as the original sweep).
        X_h, Y_h = cluster_hierarchical(path, k, save_to_disk=False)
        stat_rows.append(('hierarchical', k, *_cluster_size_stats(Y_h)))
        hierarchical_scores.append(silhouette_score(X_h, Y_h))

        X_k, Y_k = cluster_kmeans(path, k, save_to_disk=False)
        stat_rows.append(('k_means', k, *_cluster_size_stats(Y_k)))
        kmeans_scores.append(silhouette_score(X_k, Y_k))

    scores = pd.DataFrame(
        {'k_means': kmeans_scores, 'hierarchical': hierarchical_scores},
        index=ks,
        dtype=float,
    )
    stats = pd.DataFrame(stat_rows, columns=['algorithm', 'k', 'min', 'mean', 'max'])
    return scores, stats


def _plot_validation(scores, stats, limit: int, clazz_name: str) -> None:
    """Plot cluster sizes per algorithm and silhouette per k to ``<clazz_name>_stats.png``."""
    sns.set_theme(palette="hls")
    fig = plt.figure(figsize=(14, 12))
    try:
        gs = fig.add_gridspec(2, 2)
        ax1 = fig.add_subplot(gs[0, 0])
        ax2 = fig.add_subplot(gs[0, 1])
        ax3 = fig.add_subplot(gs[1, :])

        size_cols = ['k', 'min', 'mean', 'max']
        df_k = stats.loc[stats.algorithm == 'k_means', size_cols].set_index('k', drop=True)
        df_h = stats.loc[stats.algorithm == 'hierarchical', size_cols].set_index('k', drop=True)
        sns.lineplot(data=df_k, palette="tab10", ax=ax1)
        sns.lineplot(data=df_h, palette="tab10", ax=ax2)
        sns.lineplot(data=scores, palette="tab10", ax=ax3)

        # NOTE(review): 130 is a fixed y ceiling for cluster sizes; sizes above it
        # are clipped. Presumably chosen so all classes share one scale — confirm.
        ax1.set(ylabel="# of elements", ylim=[0, 130], xlabel="# of clusters", xlim=[2, limit])
        ax1.set_title("K-Means cluster sizes")
        ax2.set(ylabel="# of elements", ylim=[0, 130], xlabel="# of clusters", xlim=[2, limit])
        ax2.set_title("Hierarchical cluster sizes")
        ax3.set(ylabel="Silhouette", ylim=[0, 1], xlabel="# of clusters", xlim=[2, limit])
        ax3.set_title("Silhouette metrics per # of clusters")
        sns.despine(fig=fig, left=True, bottom=True)
        fig.savefig(OUT_DIR + '/' + clazz_name + '_stats.png')
    finally:
        # Close (not merely clear) the figure. This runs once per class, and
        # pyplot keeps every figure created by plt.figure() alive until closed.
        plt.close(fig)


def validate(path: str, clazz_name: str, autorun: bool, df_table):
    """Search the number of clusters k that maximises the silhouette score.

    For every k in ``[2, limit)``, where ``limit = min(K_MAX, number of distinct
    feature vectors in path)``, both hierarchical and k-means clustering are run
    and their silhouette scores and cluster-size statistics are recorded.

    Side effects:
      * sets row *clazz_name* of *df_table* to
        ``[best k-means k, 0, best hierarchical k, 0]``; the zeros are
        placeholders for the two silhouette columns;
      * writes ``<clazz_name>_silhouette.csv`` (score per k),
        ``<clazz_name>_stats.csv`` (cluster sizes per k) and
        ``<clazz_name>_stats.png`` (plots) into OUT_DIR;
      * if *autorun*, calls ``cluster_hierarchical`` and ``cluster_kmeans`` again
        at their best k with their default ``save_to_disk`` (presumably
        persisting the assignments; verify in those modules).

    Raises:
        ValueError: if fewer than 3 distinct feature vectors exist, so that
            no value of k can be evaluated.
    """
    # Bound k by the number of distinct points: silhouette_score accepts at
    # most n_samples - 1 labels, and more clusters than distinct points is
    # degenerate. K_MAX additionally caps the sweep.
    max_distinct = len(pd.read_csv(path, index_col=0).drop_duplicates())
    limit = min(K_MAX, max_distinct)
    if limit <= 2:
        raise ValueError(
            f"{clazz_name}: need at least 3 distinct feature vectors to search "
            f"for k, found {max_distinct}"
        )

    scores, stats = _sweep_k(path, limit)

    # Label-based idxmax on a single column (positional [0] on a labelled
    # Series is deprecated in pandas 2.x).
    k_kmeans = scores['k_means'].idxmax()
    k_hierarchical = scores['hierarchical'].idxmax()
    df_table.loc[clazz_name] = [k_kmeans, 0, k_hierarchical, 0]

    scores.to_csv(OUT_DIR + '/' + clazz_name + '_silhouette.csv')
    stats.to_csv(OUT_DIR + '/' + clazz_name + '_stats.csv')

    if autorun:
        cluster_hierarchical(path, k_hierarchical)
        cluster_kmeans(path, k_kmeans)

    _plot_validation(scores, stats, limit, clazz_name)
def compute_silhouette(path: str, clazz_name: str, suffix: str) -> float:
    """Print and return the silhouette score of a stored cluster assignment.

    Labels are read from the second column of
    ``OUT_DIR/<clazz_name>_<suffix>.csv``. Features are every column of the
    CSV at *path* except the first. Rows of the two files are paired by
    position, so both are assumed to list samples in the same order.

    Returns the score rounded to 4 decimal places.
    """
    labels_file = OUT_DIR + '/' + clazz_name + '_' + suffix + '.csv'
    labels = pd.read_csv(labels_file).iloc[:, 1].values

    raw = pd.read_csv(path)
    # The first column is an identifier, not a feature.
    features = raw.iloc[:, 1:].to_numpy()

    score = round(silhouette_score(features, labels), 4)
    print(f"Silhouette for {suffix}: {score!s}")
    return score
def main():
    """Command-line entry point.

    For every ``*.csv`` feature-vector file in IN_DIR (processed in sorted
    order), print the silhouette score of the k-means and hierarchical
    assignments stored in OUT_DIR.

    With ``--validate``, OUT_DIR is cleaned first and the best k is searched
    for each class (``--autorun`` also re-runs the clusterings at that k). A
    markdown summary table is then printed.
    """
    parser = argparse.ArgumentParser(description='Compute silhouette metric.')
    parser.add_argument('--validate', action='store_true',
                        help='compute optimal k for each algorithm')
    parser.add_argument('--autorun', action='store_true',
                        help='if validating, computes CSV for optimal clustering automatically')
    args = parser.parse_args()

    df_table = None
    if args.validate:
        clean_output()
        df_table = pd.DataFrame(columns=['KMeans K', 'KMeans silhouette',
                                         'Hierarchical K', 'Hierarchical silhouette'])

    # glob returns files in a filesystem-dependent order; sort so the printed
    # scores and the summary table are reproducible across machines.
    for f in sorted(glob.glob(IN_DIR + '/*.csv')):
        clazz_name = os.path.splitext(os.path.basename(f))[0]
        if args.validate:
            validate(f, clazz_name, args.autorun, df_table)
        # NOTE(review): without --autorun these read whatever assignments are
        # already in OUT_DIR, which may not correspond to the best k just found.
        sk = compute_silhouette(f, clazz_name, 'kmeans')
        sh = compute_silhouette(f, clazz_name, 'hierarchical')
        if args.validate:
            df_table.loc[clazz_name, 'KMeans silhouette'] = sk
            df_table.loc[clazz_name, 'Hierarchical silhouette'] = sh

    if args.validate:
        df_table.index.name = 'Class Name'
        print(df_table.to_markdown())
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()