118 lines
3.3 KiB
Python
118 lines
3.3 KiB
Python
import argparse
|
|
from typing import Iterable, Optional
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import seaborn as sns
|
|
import tqdm
|
|
from matplotlib import pyplot as plt
|
|
from sklearn.manifold import TSNE
|
|
|
|
search_data = __import__('search-data')
|
|
|
|
# Leading path component that every file entry in the ground-truth file must
# start with; read_ground_truth strips it before matching the remainder against
# the data's ``file`` column.
PREFIX: str = "./"
|
|
|
|
|
|
def read_ground_truth(
    file_path: str,
    df: pd.DataFrame,
    prefix: Optional[str] = None,
) -> Iterable[tuple[str, int]]:
    """Yield ``(query, row_index)`` pairs parsed from a ground-truth file.

    The file is a sequence of records, each made of exactly three non-blank
    lines -- the query text, a name, and a file path that must start with
    ``prefix`` -- separated by one or more blank lines.  The final record may
    or may not be followed by a blank line.  Each record is resolved to the
    single row of ``df`` whose ``name`` column equals the name and whose
    ``file`` column equals the path with ``prefix`` removed.

    The whole file is parsed and validated before the first pair is yielded.

    Args:
        file_path: Path of the ground-truth text file (read as UTF-8).
        df: Table with ``name`` and ``file`` columns.
        prefix: Leading component every file path must carry.  Defaults to
            the module-level ``PREFIX`` when omitted.

    Yields:
        ``(query, index_label)`` where ``index_label`` is the index label of
        the matching ``df`` row.

    Raises:
        ValueError: if a record does not have exactly three lines, a file
            path does not start with ``prefix``, or a record matches zero or
            several rows of ``df``.
    """
    if prefix is None:
        prefix = PREFIX

    records: list[list[str]] = []
    current: list[str] = []
    with open(file_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                current.append(line)
            elif current:
                # A blank line closes the pending record; extra blank lines
                # (runs of them, or leading ones) are simply skipped.
                records.append(current)
                current = []
    # The last record need not be terminated by a blank line.
    if current:
        records.append(current)

    # Validate every record up front so a malformed file fails before any
    # pair is produced.
    for number, record in enumerate(records, start=1):
        if len(record) != 3:
            raise ValueError(
                f"{file_path}: record {number} has {len(record)} line(s), "
                f"expected 3 (query, name, file): {record!r}"
            )

    names = df["name"]
    files = df["file"]
    for query, name, file_entry in records:
        if not file_entry.startswith(prefix):
            raise ValueError(
                f"{file_path}: file entry {file_entry!r} does not start "
                f"with {prefix!r}"
            )
        relative = file_entry[len(prefix):]

        row = df[(names == name) & (files == relative)]
        if len(row) != 1:
            raise ValueError(
                f"{file_path}: expected exactly one row for name={name!r}, "
                f"file={relative!r}; found {len(row)}"
            )

        yield query, row.index[0]
|
|
|
|
|
|
def better_index(li: list[tuple[int, float]], e: int) -> Optional[int]:
    """Return the position of the first pair in ``li`` whose first item is ``e``.

    ``li`` holds ``(index, score)`` pairs; only the first item of each pair is
    compared.  Returns ``None`` when no pair matches (including when ``li`` is
    empty).
    """
    matches = (position for position, pair in enumerate(li) if pair[0] == e)
    return next(matches, None)
|
|
|
|
|
|
def plot_df(results, query: str) -> Optional[pd.DataFrame]:
    """Project one query's result vectors plus its query vector to 2-D with t-SNE.

    Args:
        results: Search-result object exposing ``vectors`` (a sequence of
            result embeddings, or None) and ``query_vector`` (the query
            embedding, or None).  Presumably the object returned by
            ``search_data.search``; confirm against that module.
        query: Query text, copied into every row's ``query`` column.

    Returns:
        A DataFrame with columns ``tsne-2d-one``, ``tsne-2d-two``, ``query``
        and ``is_input``: one ``'Result'`` row per result vector followed by a
        single ``'Input query'`` row.  Returns None when either vector
        attribute is None or there are no result vectors.
    """
    if results.vectors is None or results.query_vector is None:
        return None

    n_results = len(results.vectors)
    if n_results == 0:
        # t-SNE requires perplexity < n_samples; the query vector alone is a
        # single sample and cannot be embedded with perplexity 1.5.
        return None

    # list(...) appends correctly whether ``vectors`` is a list or a 2-D
    # array; ``ndarray + [x]`` would broadcast-add instead of appending.
    tsne_vectors = np.array(list(results.vectors) + [results.query_vector])

    # try perplexity = 1, 1.5, 2
    tsne_kwargs = dict(n_components=2, verbose=1, perplexity=1.5)
    try:
        # scikit-learn 1.5 deprecated ``n_iter`` in favour of ``max_iter``.
        tsne = TSNE(max_iter=3000, **tsne_kwargs)
    except TypeError:
        # Older scikit-learn releases only accept ``n_iter``.
        tsne = TSNE(n_iter=3000, **tsne_kwargs)
    tsne_results = tsne.fit_transform(tsne_vectors)

    return pd.DataFrame({
        'tsne-2d-one': tsne_results[:, 0],
        'tsne-2d-two': tsne_results[:, 1],
        'query': [query] * (n_results + 1),
        'is_input': ['Result'] * n_results + ['Input query'],
    })
|
|
|
|
|
|
def _plot_tsne_frames(frames: list[pd.DataFrame]) -> None:
    """Show one scatter plot combining all per-query t-SNE frames (no-op if none)."""
    if not frames:
        # pd.concat raises on an empty list; this happens when the search
        # method returned no vectors for any query.
        print("No embedding vectors returned; skipping t-SNE plot.")
        return

    # Each frame carries its own 0..n index; ignore_index avoids duplicate
    # labels, which some seaborn versions fail to reindex.
    combined = pd.concat(frames, ignore_index=True)

    plt.figure(figsize=(4, 4))
    sns.scatterplot(
        x="tsne-2d-one", y="tsne-2d-two",
        hue="query",
        style="is_input",
        # One colour per distinct query; a fixed palette of 10 colours cycles
        # (or errors, in older seaborn) once more queries are plotted.
        palette=sns.color_palette("husl", n_colors=combined["query"].nunique()),
        data=combined,
        legend="full",
        alpha=1.0,
    )
    plt.show()


def main(method: str, file_path: str):
    """Score a search method against a ground-truth file and plot embeddings.

    For each ``(query, expected_index)`` pair from ``read_ground_truth``, runs
    ``search_data.search(query, method, df)`` and finds the rank of the
    expected row in ``search_results.indexes_scores``.  Prints two aggregate
    scores, then shows a t-SNE scatter plot when the search returned vectors.

    Note: the printed "Precision" is the mean over queries of
    1 / (rank + 1), with 0 when the expected row is absent -- i.e. mean
    reciprocal rank.  "Recall" is the fraction of queries whose expected row
    appears in the results at all.

    Args:
        method: Similarity method name, passed through to ``search_data.search``.
        file_path: Path of the ground-truth file.
    """
    df = search_data.load_data()
    test_set = list(read_ground_truth(file_path, df))
    if not test_set:
        # Nothing to score; avoids a division by zero below.
        print(f"No ground-truth queries found in {file_path}")
        return

    precision_sum = 0.0
    recall_sum = 0
    frames: list[pd.DataFrame] = []

    for query, expected in tqdm.tqdm(test_set):
        search_results = search_data.search(query, method, df)

        df_q = plot_df(search_results, query)
        if df_q is not None:
            frames.append(df_q)

        rank = better_index(search_results.indexes_scores, expected)
        if rank is not None:
            precision_sum += 1 / (rank + 1)
            recall_sum += 1

    print("Precision: {0:.2f}%".format(precision_sum * 100 / len(test_set)))
    print("Recall: {0:.2f}%".format(recall_sum * 100 / len(test_set)))

    _plot_tsne_frames(frames)
|
|
|
|
|
|
if __name__ == '__main__':
    # Command line: <method> <ground_truth_file>
    cli = argparse.ArgumentParser()
    cli.add_argument("method", type=str, help="the method to compare similarities with")
    cli.add_argument("ground_truth_file", type=str, help="file where ground truth comes from")
    parsed = cli.parse_args()
    main(parsed.method, parsed.ground_truth_file)
|