#!/usr/bin/env python3
"""Pairwise precision and recall of clusterings against a ground truth.

For every ``*_groundtruth.csv`` file in ``IN_DIR`` the script loads the
matching ``<name>_<method>.csv`` clusterings and compares them on the level
of *intra-cluster pairs*: two elements form a pair when they are assigned to
the same cluster.  Precision is the share of predicted pairs that are also
ground-truth pairs; recall is the share of ground-truth pairs that were
predicted.
"""
import glob
import itertools
import os

import numpy as np
import pandas as pd

DIR: str = os.path.dirname(os.path.realpath(__file__))
IN_DIR: str = DIR + '/clustering'
OUT_DIR: str = DIR + ''

# Clustering methods whose result files are compared with the ground truth.
_METHODS: tuple[str, ...] = ('kmeans', 'hierarchical')
_GROUNDTRUTH_SUFFIX: str = '_groundtruth.csv'


def intrapairs(path: str) -> set[frozenset[str]]:
    """Return all unordered pairs of elements that share a cluster.

    The CSV at *path* must contain a ``cluster`` column.  The elements are
    taken from the first remaining column; their type is whatever pandas
    infers for that column (presumably strings -- confirm against the data).

    Each pair is a ``frozenset`` so that ``(a, b)`` and ``(b, a)`` compare
    equal and the pairs can be stored in a set.  Clusters with fewer than two
    members contribute no pairs.  If a cluster lists the same element twice,
    the resulting "pair" collapses to a one-element frozenset.
    """
    df = pd.read_csv(path)
    # After grouping, 'cluster' becomes the index; column 0 is therefore the
    # first non-cluster column, aggregated into one member list per cluster.
    members_by_cluster = df.groupby('cluster').agg(list).iloc[:, 0]

    pairs: set[frozenset[str]] = set()
    for members in members_by_cluster:
        pairs.update(map(frozenset, itertools.combinations(members, 2)))
    return pairs


def _ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or NaN when the ratio is undefined."""
    return numerator / denominator if denominator else float('nan')


def precision_recall(ground_pairs: set, cluster_pairs: set) -> tuple[float, float]:
    """Return ``(precision, recall)`` of *cluster_pairs* against *ground_pairs*.

    A metric whose denominator is zero (no predicted pairs for precision, no
    ground-truth pairs for recall) is reported as NaN instead of raising
    ``ZeroDivisionError``; this happens when every cluster is a singleton.
    """
    n_common = len(ground_pairs & cluster_pairs)
    precision = _ratio(n_common, len(cluster_pairs))
    recall = _ratio(n_common, len(ground_pairs))
    return precision, recall


def main():
    """Print precision and recall of every method for every ground-truth file."""
    pattern = os.path.join(IN_DIR, '*' + _GROUNDTRUTH_SUFFIX)
    # glob order depends on the filesystem; sort for reproducible output.
    for f in sorted(glob.glob(pattern)):
        clazz_name = os.path.basename(f).removesuffix(_GROUNDTRUTH_SUFFIX)
        print(clazz_name)

        ground_pairs = intrapairs(f)
        for method in _METHODS:
            cluster_pairs = intrapairs(
                os.path.join(IN_DIR, f'{clazz_name}_{method}.csv'))
            precision, recall = precision_recall(ground_pairs, cluster_pairs)

            print(f"{method} precision: {precision}")
            print(f"{method} recall: {recall}")

        print()


if __name__ == '__main__':
    main()