"""Evaluate a search method against a ground-truth file.

Reports precision/recall of the expected result for each query and, when the
search results expose vectors, saves a t-SNE scatter plot of them.
"""
import argparse
import os.path
from typing import Iterable, Optional

import numpy as np
import pandas as pd
import seaborn as sns
import tqdm
from matplotlib import pyplot as plt
from sklearn.manifold import TSNE

# The hyphenated module name cannot appear in an ``import`` statement, hence
# ``__import__``.
search_data = __import__('search-data')

# Prefix every ground-truth file path must start with; it is stripped before
# matching against the ``file`` column of the data frame.
TENSORFLOW_PATH_PREFIX: str = "./"
# Output directory (next to this script) for metric files and plots.
OUT_DIR: str = os.path.join(os.path.dirname(__file__), "out")

# Non-blank lines per ground-truth record: query, entity name, file path.
_RECORD_LEN = 3
# t-SNE perplexity used by ``plot_df``; t-SNE needs more samples than this.
_TSNE_PERPLEXITY = 1.5


def _parse_records(file_path: str) -> list[list[str]]:
    """Split *file_path* into blank-line-separated records of stripped lines.

    Runs of several blank lines are treated as one separator.

    Raises:
        ValueError: if a record does not have exactly ``_RECORD_LEN`` lines.
    """
    records: list[list[str]] = []
    current: list[str] = []

    def flush() -> None:
        # Close the record being collected, validating its size first.
        if not current:
            return
        if len(current) != _RECORD_LEN:
            raise ValueError(
                f"{file_path}: expected records of {_RECORD_LEN} lines, "
                f"got {len(current)}: {current!r}")
        records.append(list(current))
        current.clear()

    with open(file_path, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line:
                current.append(line)
            else:
                flush()
    # The last record need not be followed by a blank line.
    flush()
    return records


def read_ground_truth(file_path: str, df: pd.DataFrame) -> Iterable[tuple[str, int]]:
    """Yield ``(query, expected_row_index)`` pairs from a ground-truth file.

    The file consists of blank-line-separated records of three lines:
    the query text, the entity name, and its file path (which must start with
    ``TENSORFLOW_PATH_PREFIX``). Each record is resolved to the single row of
    *df* whose ``name`` and ``file`` columns match.

    Raises:
        ValueError: on a malformed record, a path without the expected prefix,
            or when the record does not match exactly one row of *df*.
    """
    for query, name, file_name in _parse_records(file_path):
        if not file_name.startswith(TENSORFLOW_PATH_PREFIX):
            raise ValueError(
                f"path {file_name!r} does not start with {TENSORFLOW_PATH_PREFIX!r}")
        file_name = file_name[len(TENSORFLOW_PATH_PREFIX):]
        row = df[(df["name"] == name) & (df["file"] == file_name)]
        if len(row) != 1:
            raise ValueError(
                f"expected exactly one row for name={name!r}, file={file_name!r}; "
                f"found {len(row)}")
        yield query, row.index[0]


def better_index(li: list[tuple[int, float]], e: int) -> Optional[int]:
    """Return the position of the first pair in *li* whose first item is *e*.

    Returns None when no pair matches.
    """
    return next((pos for pos, pair in enumerate(li) if pair[0] == e), None)


def plot_df(results, query: str) -> Optional[pd.DataFrame]:
    """Embed a query's result vectors and the query vector in 2-D with t-SNE.

    Returns a frame with columns ``tsne-2d-one``, ``tsne-2d-two``, ``Query``
    and ``Vector kind`` (``'Result'`` for each result, ``'Input query'`` for
    the last row). Returns None when ``results.vectors`` or
    ``results.query_vector`` is None, or when there are too few vectors for
    t-SNE at the configured perplexity.
    """
    if results.vectors is None or results.query_vector is None:
        return None
    n_results = len(results.vectors)
    # list() lets ``vectors`` be any sequence of rows (e.g. a 2-D array).
    tsne_input = np.array(list(results.vectors) + [results.query_vector])
    if len(tsne_input) <= _TSNE_PERPLEXITY:
        return None
    # NOTE(review): newer scikit-learn versions rename ``n_iter`` to
    # ``max_iter`` — confirm against the installed version.
    tsne = TSNE(n_components=2, verbose=1, perplexity=_TSNE_PERPLEXITY, n_iter=3000)
    embedded = tsne.fit_transform(tsne_input)
    return pd.DataFrame({
        'tsne-2d-one': embedded[:, 0],
        'tsne-2d-two': embedded[:, 1],
        'Query': [query] * len(tsne_input),
        'Vector kind': ['Result'] * n_results + ['Input query'],
    })


def _save_tsne_plot(data: pd.DataFrame, out_path: str) -> None:
    """Draw the combined t-SNE frames as one scatter plot saved to *out_path*."""
    fig = plt.figure(figsize=(20, 16))
    sns.scatterplot(
        x="tsne-2d-one", y="tsne-2d-two",
        hue="Query",
        style="Vector kind",
        # One colour per distinct query, whatever the number of queries.
        palette=sns.color_palette("husl", n_colors=data["Query"].nunique()),
        data=data,
        legend="full",
        alpha=1.0,
    )
    plt.savefig(out_path)
    plt.close(fig)


def main(method: str, file_path: str):
    """Evaluate search *method* against the ground truth in *file_path*.

    For each query, if the expected row appears at 0-based position ``i`` of
    the search results, the query scores ``1 / (i + 1)`` for "precision"
    (so the reported figure is a mean reciprocal rank) and 1 for recall;
    otherwise it scores 0 for both. The averages are printed and written to
    ``OUT_DIR/<method>_prec_recall.txt``; if any query produced vectors, a
    t-SNE plot is saved to ``OUT_DIR/<method>_plot.png``.

    Raises:
        ValueError: if the ground-truth file contains no records.
    """
    df = search_data.load_data()
    test_set = list(read_ground_truth(file_path, df))
    if not test_set:
        raise ValueError(f"no ground-truth records found in {file_path!r}")

    precision_sum = 0.0
    recall_sum = 0
    plot_frames = []
    for query, expected in tqdm.tqdm(test_set):
        search_results = search_data.search(query, method, df)
        frame = plot_df(search_results, query)
        if frame is not None:
            plot_frames.append(frame)
        position = better_index(search_results.indexes_scores, expected)
        if position is not None:
            precision_sum += 1 / (position + 1)
            recall_sum += 1

    os.makedirs(OUT_DIR, exist_ok=True)
    n_queries = len(test_set)
    output = "Precision: {0:.2f}%\nRecall: {1:.2f}%\n".format(
        precision_sum * 100 / n_queries, recall_sum * 100 / n_queries)
    print(output)
    with open(os.path.join(OUT_DIR, "{0}_prec_recall.txt".format(method)), "w",
              encoding="utf-8") as f:
        f.write(output)

    if plot_frames:
        # ignore_index avoids duplicate row labels from the per-query frames.
        _save_tsne_plot(pd.concat(plot_frames, ignore_index=True),
                        os.path.join(OUT_DIR, "{0}_plot.png".format(method)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("method", help="the method to compare similarities with", type=str)
    parser.add_argument("ground_truth_file", help="file where ground truth comes from", type=str)
    args = parser.parse_args()
    main(args.method, args.ground_truth_file)