This repository has been archived on 2023-06-18. You can view files and clone it, but cannot push or open issues or pull requests.
ima01/hierarchical.py

49 lines
1.4 KiB
Python
Executable file

#!/usr/bin/env python3
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import glob
import os
import pandas as pd
import argparse
# Absolute directory containing this script; data folders live beside it.
DIR: str = os.path.dirname(os.path.realpath(__file__))
# Destination folder for the per-class clustering CSVs.
OUT_DIR: str = f'{DIR}/clustering'
# Source folder holding the per-class feature-vector CSVs.
IN_DIR: str = f'{DIR}/feature_vectors'
def cluster_hierarchical(path: str, n_clusters: int, save_to_disk: bool = True,
                         linkage: str = 'complete') -> tuple[np.ndarray, np.ndarray]:
    """Run agglomerative clustering on the feature vectors stored in a CSV file.

    The first CSV column is used as the row label (e.g. a method name); every
    remaining column is treated as a feature of that row.

    Args:
        path: Path to the feature-vector CSV. Its file name without the
            extension is used as the class name in the output file name.
        n_clusters: Number of clusters to form.
        save_to_disk: When true, write the label -> cluster mapping to
            ``OUT_DIR/<class>_hierarchical.csv``, creating ``OUT_DIR`` if needed.
        linkage: Linkage criterion forwarded to ``AgglomerativeClustering``
            (default ``'complete'``, the previously hard-coded value).

    Returns:
        ``(X, Y)`` where ``X`` is the feature matrix and ``Y`` holds the
        cluster index assigned to each row of ``X``.
    """
    # splitext handles names without an extension correctly; the previous
    # rfind('.') slice returned -1 there and dropped the last character.
    clazz_name = os.path.splitext(os.path.basename(path))[0]

    df = pd.read_csv(path)
    # Everything except the first (label) column is a feature.
    X = df.drop(df.columns[0], axis=1).to_numpy()

    model = AgglomerativeClustering(n_clusters=n_clusters, linkage=linkage).fit(X)
    Y = model.labels_  # cluster index assigned to each row

    if save_to_disk:
        # Index the cluster labels by the row labels from the first column.
        assigned = pd.DataFrame(Y, columns=['cluster']).set_axis(df.iloc[:, 0].values)
        os.makedirs(OUT_DIR, exist_ok=True)
        assigned.to_csv(os.path.join(OUT_DIR, clazz_name + '_hierarchical.csv'))
    return X, Y
def main():
    """Command-line entry point: cluster one god class's feature vectors.

    Reads ``IN_DIR/<class_name>.csv``, removes any previous
    ``OUT_DIR/<class_name>_hierarchical.csv`` and lets
    ``cluster_hierarchical`` write a fresh one.
    """
    parser = argparse.ArgumentParser(
        description='Compute agglomerative clustering')
    parser.add_argument('class_name', type=str, help='name of the god class')
    parser.add_argument('n_clusters', type=int, help='number of clusters')
    args = parser.parse_args()
    if args.n_clusters < 1:
        # Report bad input as a usage error instead of a library traceback.
        parser.error('n_clusters must be a positive integer')

    path = os.path.join(IN_DIR, args.class_name + '.csv')
    out_path = os.path.join(OUT_DIR, args.class_name + '_hierarchical.csv')

    # Remove stale output first so a failed run cannot leave an old result
    # behind. On a first run the file does not exist yet; that is not an
    # error (previously this crashed with FileNotFoundError).
    try:
        os.remove(out_path)
    except FileNotFoundError:
        pass

    cluster_hierarchical(path, args.n_clusters)
# Run the command-line entry point only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()