From 06beb66d5023f36e61348e3ae4eb98afdaf8a525 Mon Sep 17 00:00:00 2001
From: Claudio Maggioni
Date: Mon, 16 Oct 2023 15:10:45 +0200
Subject: [PATCH] wip word2vec

---
Notes (ignored by `git am`):
  * Adds a "doc2vec" search method next to the existing "tfidf"/"lsi" ones.
  * Extracts the result-printing loop of print_sims() into print_results(),
    which takes (index, score) pairs plus the DataFrame used to look rows up.
  * Doc2Vec is trained on the tokenised documents (corpus_list) so that its
    vocabulary matches the string tokens later given to infer_vector().
  * An unrecognised method now raises ValueError.

 search-data.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/search-data.py b/search-data.py
index 526a9732..aff08089 100644
--- a/search-data.py
+++ b/search-data.py
@@ -7,6 +7,7 @@ import numpy as np
 from nltk.corpus import stopwords
 from gensim.similarities import SparseMatrixSimilarity, MatrixSimilarity
 from gensim.models import TfidfModel, LsiModel, LdaModel
+from gensim.models.doc2vec import TaggedDocument, Doc2Vec
 from gensim.corpora import Dictionary
 from collections import defaultdict
 
@@ -53,7 +54,10 @@ def print_sims(corpus, query, df, dictionary):
     sims = index[query]
     pick_top = 5
 
-    for idx, score in sorted(enumerate(sims), key=lambda x: x[1], reverse=True)[:pick_top]:
+    print_results(sorted(enumerate(sims), key=lambda x: x[1], reverse=True)[:pick_top], df)
+
+def print_results(idxs_scores, df):
+    for idx, score in idxs_scores:
         row = df.loc[idx]
         print("Similarity: {s:2.02f}%".format(s=score*100))
         print("Python {feat}: {name}\nFile: {file}\nLine: {line}\n" \
@@ -72,7 +76,8 @@ def search(query, method):
     dictionary = Dictionary(corpus_list)
     corpus_bow = [dictionary.doc2bow(text) for text in corpus_list]
 
-    query_bow = dictionary.doc2bow(get_bow(query, comment_split))
+    query_w = get_bow(query, comment_split)
+    query_bow = dictionary.doc2bow(query_w)
 
     if method == "tfidf":
         tfidf = TfidfModel(corpus_bow)
@@ -82,6 +87,15 @@ def search(query, method):
     elif method == "lsi":
         lsi = LsiModel(corpus_bow)
         print_sims(lsi[corpus_bow], lsi[query_bow], df, dictionary)
+    elif method == "doc2vec":
+        dvdocs = [TaggedDocument(words, [i]) for i, words in enumerate(corpus_list)]
+        model = Doc2Vec(vector_size=50, min_count=2, epochs=100)
+        model.build_vocab(dvdocs)
+        model.train(dvdocs, total_examples=model.corpus_count, epochs=model.epochs)
+        dvquery = model.infer_vector(query_w)
+        print_results(model.dv.most_similar([dvquery], topn=5), df)
+    else:
+        raise ValueError("method unknown")
 
 
 def main():