#!/usr/bin/env python3
import glob
import itertools
import os

import numpy as np
import pandas as pd

# Absolute directory of this script (symlinks resolved); paths below hang off it.
DIR: str = os.path.split(os.path.realpath(__file__))[0]
# Input directory holding the per-class clustering CSV files.
IN_DIR: str = f'{DIR}/clustering'
# Output directory: identical to DIR (not referenced by the code visible here).
OUT_DIR: str = DIR


def intrapairs(path: str) -> set[frozenset[str]]:
    """Return every unordered pair of distinct items that share a cluster.

    The CSV at *path* must contain a ``cluster`` column. Items are taken from
    the first remaining column once ``cluster`` has become the group key.
    Rows with a missing ``cluster`` value are dropped by ``groupby``.

    Args:
        path: Path to a clustering CSV file.

    Returns:
        A set of 2-element frozensets, one per co-clustered pair. Pairs are
        unordered, so ``{a, b}`` and ``{b, a}`` are the same element.
    """
    df = pd.read_csv(path)
    # After grouping, 'cluster' is the index, so iloc[:, 0] is the first
    # non-key column: one list of member items per cluster.
    clusters = df.groupby('cluster').agg(list).iloc[:, 0]

    # Deduplicate each cluster's members first: a repeated item would
    # otherwise yield frozenset({x, x}) == frozenset({x}), a 1-element
    # "pair" that skews precision/recall.
    return {
        frozenset(pair)
        for members in clusters
        for pair in itertools.combinations(set(members), 2)
    }


def _precision_recall(ground_pairs: set[frozenset[str]],
                      cluster_pairs: set[frozenset[str]]
                      ) -> tuple[float, float]:
    """Return ``(precision, recall)`` of *cluster_pairs* vs *ground_pairs*.

    Precision is the fraction of predicted pairs that are also ground-truth
    pairs; recall is the fraction of ground-truth pairs that were predicted.
    A ratio whose denominator set is empty is undefined and is reported as
    NaN instead of raising ZeroDivisionError.
    """
    n_common = len(ground_pairs & cluster_pairs)
    precision = (n_common / len(cluster_pairs)
                 if cluster_pairs else float('nan'))
    recall = n_common / len(ground_pairs) if ground_pairs else float('nan')
    return precision, recall


def main():
    """Print pairwise precision/recall of each clustering method per class.

    For every ``<class>_groundtruth.csv`` in IN_DIR, the co-clustered pairs
    of ``<class>_kmeans.csv`` and ``<class>_hierarchical.csv`` are compared
    against the ground-truth pairs.
    """
    pattern = os.path.join(IN_DIR, '*_groundtruth.csv')
    # glob returns files in arbitrary order; sort for a reproducible report.
    for ground_path in sorted(glob.glob(pattern)):
        clazz_name = os.path.basename(ground_path).removesuffix(
            '_groundtruth.csv')
        print(clazz_name)

        ground_pairs = intrapairs(ground_path)
        for method in ('kmeans', 'hierarchical'):
            cluster_pairs = intrapairs(
                os.path.join(IN_DIR, f'{clazz_name}_{method}.csv'))
            precision, recall = _precision_recall(ground_pairs, cluster_pairs)
            print(f'{method} precision: {precision}')
            print(f'{method} recall: {recall}')

        print()


if __name__ == '__main__':
    main()