This repository has been archived on 2023-06-18. You can view files and clone it, but cannot push or open issues or pull requests.
ima01/prec_recall.py

61 lines
1.9 KiB
Python
Executable File

#!/usr/bin/env python3
import glob
import itertools
import os

import numpy as np
import pandas as pd
# Absolute directory of this script (symlinks resolved); other paths hang off it.
DIR: str = os.path.dirname(os.path.realpath(__file__))
# Folder holding the per-class clustering CSV files.
IN_DIR: str = f'{DIR}/clustering'
# Output folder; currently the same directory as DIR.
OUT_DIR: str = DIR
def intrapairs(path: str, cluster_col: str = 'cluster') -> set[frozenset[str]]:
    """Return every unordered pair of items that share a cluster in a CSV.

    The CSV is read with ``pandas.read_csv`` and grouped by *cluster_col*.
    The first remaining column supplies the item identifiers. Each pair is a
    two-element ``frozenset``, so ``{a, b}`` and ``{b, a}`` compare equal and
    can be stored in a set.

    Args:
        path: Path to the CSV file (or anything ``pandas.read_csv`` accepts).
        cluster_col: Name of the column holding the cluster label.

    Returns:
        The set of intra-cluster pairs. An item listed more than once in the
        same cluster never forms a one-element "pair" with itself.
    """
    df = pd.read_csv(path)
    # agg(list) collapses each cluster into one list per column; iloc[:, 0]
    # keeps the first non-grouping column as the member identifiers.
    clusters = df.groupby(cluster_col).agg(list).iloc[:, 0]
    pairs: set[frozenset[str]] = set()
    for members in clusters:
        for first, second in itertools.combinations(members, 2):
            # A duplicate entry of the same item is not a pair of two members.
            if first != second:
                pairs.add(frozenset((first, second)))
    return pairs
# (file-name suffix of the clustering CSV, display name used in the table)
_METHODS: tuple[tuple[str, str], ...] = (
    ('kmeans', 'KMeans'),
    ('hierarchical', 'Agglomerative'),
)


def _format_percent(numerator: int, denominator: int) -> str:
    """Format ``numerator / denominator`` as a percentage string like ``'66.67%'``.

    Returns ``'n/a'`` when *denominator* is zero, because the ratio is
    undefined there (e.g. a clustering made only of singletons has no
    intra-cluster pairs).
    """
    if denominator == 0:
        return 'n/a'
    return str(round(numerator / denominator * 100, 2)) + '%'


def main():
    """Print a Markdown table of pair-counting precision and recall per class.

    For each ``<class>_groundtruth.csv`` in ``IN_DIR``, its intra-cluster
    pairs are compared against those of ``<class>_kmeans.csv`` and
    ``<class>_hierarchical.csv`` in the same folder:

    * precision = |common pairs| / |pairs in the clustering|
    * recall    = |common pairs| / |pairs in the ground truth|
    """
    suffix = '_groundtruth.csv'
    # glob returns files in arbitrary order; sort for a reproducible table.
    filelist = sorted(glob.glob(os.path.join(IN_DIR, '*' + suffix)))
    columns = [
        f'{algo} {metric}'
        for _, algo in _METHODS
        for metric in ('Precision', 'Recall')
    ]
    class_names: list[str] = []
    rows: list[list[str]] = []
    for f in filelist:
        clazz_name = os.path.basename(f).removesuffix(suffix)
        ground_pairs = intrapairs(f)
        row: list[str] = []
        for method, _algo in _METHODS:
            cluster_pairs = intrapairs(
                os.path.join(IN_DIR, f'{clazz_name}_{method}.csv'))
            n_common = len(ground_pairs & cluster_pairs)
            row.append(_format_percent(n_common, len(cluster_pairs)))  # precision
            row.append(_format_percent(n_common, len(ground_pairs)))   # recall
        class_names.append(clazz_name)
        rows.append(row)
    # Build the frame once instead of enlarging it row by row via .loc.
    df_table = pd.DataFrame(
        rows,
        index=pd.Index(class_names, name='Class Name'),
        columns=columns,
    )
    print(df_table.to_markdown())


if __name__ == '__main__':
    main()