# kse-01/prec-recall.py
# Repository viewer metadata: 77 lines, 2.1 KiB, Python; revision dated 2023-10-23 13:42:25 +00:00.
import argparse
from typing import Iterable, Optional
import pandas as pd
search_data = __import__('search-data')
# Every file path in the ground truth file must start with this prefix; it is
# removed before the path is compared with the `file` column of the data frame.
PREFIX: str = "./"


def read_ground_truth(file_path: str, df: pd.DataFrame) -> Iterable[tuple[str, int]]:
    """Yield a ``(query, index label)`` pair for every ground truth record.

    The file is read as records of exactly three non-empty lines (query,
    name, file path, in that order) separated by one or more blank lines.
    Leading, trailing and repeated blank lines are ignored. Each line is
    stripped of surrounding whitespace.

    Args:
        file_path: Path of the ground truth text file (decoded as UTF-8).
        df: Data frame with ``name`` and ``file`` columns in which each
            record is looked up.

    Yields:
        The query string and the index label of the single row of ``df``
        whose ``name`` and ``file`` match the record.

    Raises:
        ValueError: If a record does not have exactly three lines, if a file
            path does not start with ``PREFIX``, or if a record does not
            match exactly one row of ``df``.
    """
    records: list[list[str]] = []
    current: list[str] = []
    with open(file_path, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line:
                current.append(line)
            elif current:
                # A blank line closes the record being collected; blank lines
                # seen while no record is open are simply skipped.
                records.append(current)
                current = []
    if current:
        # The last record does not need a terminating blank line.
        records.append(current)

    # Validate the shape of every record before yielding anything, so that a
    # malformed file is rejected up front rather than half-way through.
    for position, record in enumerate(records, start=1):
        if len(record) != 3:
            raise ValueError(
                f"{file_path}: record {position} has {len(record)} line(s), "
                f"expected 3 (query, name, file): {record!r}"
            )

    for query, name, file_name in records:
        if not file_name.startswith(PREFIX):
            raise ValueError(
                f"{file_path}: file path {file_name!r} does not start with {PREFIX!r}"
            )
        relative_name = file_name.removeprefix(PREFIX)
        row = df[(df["name"] == name) & (df["file"] == relative_name)]
        if len(row) != 1:
            raise ValueError(
                f"{file_path}: expected exactly one row with name={name!r} and "
                f"file={relative_name!r}, found {len(row)}"
            )
        yield query, row.index[0]
def better_index(li: list[tuple[int, float]], e: int) -> Optional[int]:
    """Return the position of the first entry of ``li`` whose first item is ``e``.

    Args:
        li: Sequence of tuples; only the first item of each tuple is compared.
        e: Value to look for.

    Returns:
        The zero-based position of the first matching entry, or ``None`` when
        no entry matches (including when ``li`` is empty).
    """
    matching_positions = (pos for pos, entry in enumerate(li) if entry[0] == e)
    # `next` with a default stops at the first hit and yields None otherwise.
    return next(matching_positions, None)
def main(method: str, file_path: str) -> None:
    """Score a search method against a ground truth file and print the result.

    For every ground truth query, ``search_data.search`` is run and the
    position of the expected entry in its result is looked up. A query
    contributes ``1 / (position + 1)`` to the precision and ``1`` to the
    recall when the expected entry is found, and ``0`` to both otherwise.
    The averages over all queries are printed as percentages.

    Args:
        method: Name of the similarity method, passed through to
            ``search_data.search``.
        file_path: Path of the ground truth file read by
            ``read_ground_truth``.

    Raises:
        ValueError: If the ground truth file contains no records, since the
            averages would otherwise be a division by zero.
    """
    df = search_data.load_data()
    test_set = list(read_ground_truth(file_path, df))
    if not test_set:
        raise ValueError(f"no ground truth records found in {file_path!r}")

    precision_sum = 0.0
    recall_sum = 0
    for query, expected in test_set:
        # NOTE(review): presumably ordered best match first -- confirm in search-data.
        indexes_values: list[tuple[int, float]] = search_data.search(query, method, df)
        idx = better_index(indexes_values, expected)
        if idx is not None:
            # A miss adds nothing to either sum.
            precision_sum += 1 / (idx + 1)
            recall_sum += 1

    total = len(test_set)
    print("Precision: {0:.2f}%".format(precision_sum * 100 / total))
    print("Recall: {0:.2f}%".format(recall_sum * 100 / total))
if __name__ == '__main__':
    # Command-line entry point: two required positional arguments, forwarded
    # unchanged to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "method",
        type=str,
        help="the method to compare similarities with",
    )
    cli.add_argument(
        "ground_truth_file",
        type=str,
        help="file where ground truth comes from",
    )
    options = cli.parse_args()
    main(options.method, options.ground_truth_file)