kse-01/prec-recall.py

150 lines
4.4 KiB
Python
Raw Normal View History

2023-10-23 13:42:25 +00:00
import argparse
2023-10-25 13:42:58 +00:00
import os.path
2023-10-23 13:42:25 +00:00
from typing import Iterable, Optional
2023-10-25 13:10:47 +00:00
import numpy as np
2023-10-23 13:42:25 +00:00
import pandas as pd
2023-10-25 13:10:47 +00:00
import seaborn as sns
import tqdm
from matplotlib import pyplot as plt
from sklearn.manifold import TSNE
2023-10-23 13:42:25 +00:00
# The module is named "search-data"; the hyphen makes it impossible to load
# with a regular ``import`` statement, so it is imported by name instead.
search_data = __import__('search-data')
2023-10-25 13:42:58 +00:00
# Prefix that every file path in the ground truth file must start with; it is
# stripped before the path is compared against the ``file`` column of the data.
TENSORFLOW_PATH_PREFIX: str = "./"
# Output directory ("out", next to this script) that receives the
# precision/recall report and the t-SNE plot of each evaluated method.
OUT_DIR: str = os.path.join(os.path.dirname(__file__), "out")
2023-10-23 13:42:25 +00:00
def _read_records(file_path: str) -> list[list[str]]:
    """Parse the ground truth file into records of exactly three lines.

    Records are separated by one or more blank lines. Leading, trailing and
    repeated blank lines are tolerated.

    Raises:
        ValueError: if a record does not consist of exactly three non-blank
            lines (including an incomplete record at the end of the file).
    """
    with open(file_path, encoding="utf-8") as f:
        lines = [line.strip() for line in f]

    records: list[list[str]] = []
    current: list[str] = []
    # The trailing empty string acts as a final separator, so the last record
    # is validated and flushed even when the file does not end in a blank line.
    for line_no, line in enumerate(lines + [""], start=1):
        if line:
            current.append(line)
            continue
        if not current:
            # Extra blank line between records: nothing to flush.
            continue
        if len(current) != 3:
            raise ValueError(
                f"{file_path}: record ending before line {line_no} has "
                f"{len(current)} line(s), expected 3 (query, name, file)"
            )
        records.append(current)
        current = []
    return records


def read_ground_truth(file_path: str, df: pd.DataFrame) -> Iterable[tuple[str, int]]:
    """Yield ``(query, index)`` pairs for every record of the ground truth file.

    Each record consists of three lines: the query text, the expected entity
    name and the expected file path (which must start with
    ``TENSORFLOW_PATH_PREFIX``). The record is resolved to the single row of
    ``df`` whose ``name`` and ``file`` columns match.

    Args:
        file_path: path of the ground truth file.
        df: data frame with (at least) ``name`` and ``file`` columns.

    Yields:
        The query string and the index label of the matching row in ``df``.

    Raises:
        ValueError: if the file is malformed, a path lacks the expected
            prefix, or a record does not match exactly one row of ``df``.
    """
    for query, name, file_name in _read_records(file_path):
        # Explicit exceptions instead of ``assert`` so the checks survive -O.
        if not file_name.startswith(TENSORFLOW_PATH_PREFIX):
            raise ValueError(
                f"{file_path}: path {file_name!r} does not start with "
                f"{TENSORFLOW_PATH_PREFIX!r}"
            )
        relative_path = file_name[len(TENSORFLOW_PATH_PREFIX):]

        matches = df[(df.name == name) & (df.file == relative_path)]
        if len(matches) != 1:
            raise ValueError(
                f"{file_path}: expected exactly one entry named {name!r} in "
                f"{relative_path!r}, found {len(matches)}"
            )
        yield query, matches.index[0]
def better_index(li: list[tuple[int, float]], e: int) -> Optional[int]:
    """Return the position of the first pair in ``li`` whose first item is ``e``.

    Args:
        li: sequence of pairs; only the first item of each pair is compared.
        e: value to look for.

    Returns:
        The zero-based position of the first match, or ``None`` if no pair
        starts with ``e``.
    """
    matching_positions = (pos for pos, pair in enumerate(li) if pair[0] == e)
    return next(matching_positions, None)
2023-10-25 13:10:47 +00:00
def plot_df(results, query: str) -> Optional[pd.DataFrame]:
    """Build a 2-D t-SNE projection of a search's result and query vectors.

    Args:
        results: search result object exposing ``vectors`` and
            ``query_vector`` attributes.
        query: query text, used to label every row of the returned frame.

    Returns:
        A frame with the columns ``tsne-2d-one``, ``tsne-2d-two``, ``Query``
        and ``Vector kind`` (one row per result vector followed by one row
        for the query vector), or ``None`` when either ``results.vectors`` or
        ``results.query_vector`` is ``None``.
    """
    if results.vectors is None or results.query_vector is None:
        return None

    # The query vector is appended last; the "Vector kind" labels below rely
    # on that ordering.
    stacked = np.array(results.vectors + [results.query_vector])
    projection = TSNE(n_components=2, perplexity=1, n_iter=3000).fit_transform(stacked)

    n_results = len(results.vectors)
    return pd.DataFrame({
        'tsne-2d-one': projection[:, 0],
        'tsne-2d-two': projection[:, 1],
        'Query': [query] * (n_results + 1),
        'Vector kind': (['Result'] * n_results) + ['Input query'],
    })
2023-11-07 10:48:00 +00:00
def _save_tsne_plot(frames: list, method_name: str) -> None:
    """Draw all per-query t-SNE frames into one scatter plot saved in OUT_DIR."""
    # Fresh index so rows coming from different queries do not share labels.
    plot_data = pd.concat(frames, ignore_index=True)

    fig = plt.figure(figsize=(12, 10))
    try:
        sns.scatterplot(
            x="tsne-2d-one", y="tsne-2d-two",
            hue="Query",
            style="Vector kind",
            # One colour per plotted query instead of a fixed count of 10.
            palette=sns.color_palette("husl", n_colors=plot_data["Query"].nunique()),
            data=plot_data,
            legend="full",
            alpha=1.0,
        )
        fig.savefig(os.path.join(OUT_DIR, "{0}_plot.png".format(method_name)))
    finally:
        # Release the figure; evaluating several methods in one run would
        # otherwise keep every figure alive.
        plt.close(fig)


def evaluate(method_name: str, file_path: str) -> tuple[float, float]:
    """Evaluate one search method against the ground truth file.

    For every ground truth query the search is run with ``method_name``. If
    the expected entry appears in the results at zero-based position ``i``,
    the query contributes ``1 / (i + 1)`` to precision and ``1`` to recall;
    otherwise it contributes ``0`` to both. Both sums are averaged over all
    queries and expressed as percentages.

    Side effects: prints the report, writes it to
    ``OUT_DIR/<method_name>_prec_recall.txt`` and, when the search results
    carry vectors, saves a t-SNE scatter plot to
    ``OUT_DIR/<method_name>_plot.png``.

    Args:
        method_name: name of the search method passed to ``search_data.search``.
        file_path: path of the ground truth file.

    Returns:
        ``(precision, recall)`` as percentages.

    Raises:
        ValueError: if the ground truth file contains no records, since the
            averages are undefined in that case.
    """
    corpus = search_data.load_data()
    test_set = list(read_ground_truth(file_path, corpus))
    if not test_set:
        raise ValueError(f"no ground truth records found in {file_path!r}")

    precision_sum = 0.0
    recall_sum = 0.0
    frames = []
    for query, expected in tqdm.tqdm(test_set):
        search_results = search_data.search(query, method_name, corpus)

        frame = plot_df(search_results, query)
        if frame is not None:
            frames.append(frame)

        rank = better_index(search_results.indexes_scores, expected)
        if rank is not None:
            precision_sum += 1 / (rank + 1)
            recall_sum += 1

    precision = precision_sum * 100 / len(test_set)
    recall = recall_sum * 100 / len(test_set)

    output = "Precision: {0:.2f}%\nRecall: {1:.2f}%\n".format(precision, recall)
    print(output)

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(OUT_DIR, exist_ok=True)
    with open(os.path.join(OUT_DIR, "{0}_prec_recall.txt".format(method_name)), "w") as f:
        f.write(output)

    if frames:
        _save_tsne_plot(frames, method_name)

    return precision, recall
def main():
    """Command-line entry point.

    Parses the search method and the ground truth file from the command line.
    For a single method it just runs the evaluation; for ``all`` it evaluates
    every known method in turn and prints a Markdown summary table.
    """
    methods = ["tfidf", "freq", "lsi", "doc2vec"]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "method",
        help="the method to compare similarities with",
        type=str,
        choices=methods + ["all"],
    )
    parser.add_argument(
        "ground_truth_file",
        help="file where ground truth comes from",
        type=str,
    )
    args = parser.parse_args()

    if args.method != "all":
        evaluate(args.method, args.ground_truth_file)
        return

    # Collect one summary row per method, then build the table in one go.
    summary_rows = []
    for method in methods:
        print(f"Applying method {method}:")
        precision, recall = evaluate(method, args.ground_truth_file)
        summary_rows.append({
            "Engine": method,
            "Average Precision": f"{precision:.2f}%",
            "Average Recall": f"{recall:.2f}%",
        })

    summary = pd.DataFrame(
        summary_rows,
        columns=["Engine", "Average Precision", "Average Recall"],
    )
    print(summary.to_markdown(index=False))
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()