kse-01/search-data.py

148 lines
4.7 KiB
Python
Raw Normal View History

2023-10-11 11:59:07 +00:00
import argparse
2023-10-23 13:42:25 +00:00
import logging
2023-10-11 11:59:07 +00:00
import os
2023-10-23 13:42:25 +00:00
import re
import coloredlogs
2023-10-11 12:35:41 +00:00
import nltk
import numpy as np
2023-10-23 13:42:25 +00:00
import pandas as pd
2023-10-11 15:49:38 +00:00
from gensim.corpora import Dictionary
2023-10-23 13:42:25 +00:00
from gensim.models import TfidfModel, LsiModel
from gensim.models.doc2vec import TaggedDocument, Doc2Vec
from gensim.similarities import SparseMatrixSimilarity
from nltk.corpus import stopwords
2023-10-11 12:35:41 +00:00
# Fetch the NLTK stop-word corpus before it is read further down.
nltk.download('stopwords')

# The input dataset and the saved Doc2Vec model are located next to this script.
SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
IN_DATASET = os.path.join(SCRIPT_DIR, "data.csv")
DOC2VEC_MODEL = os.path.join(SCRIPT_DIR, "doc2vec_model.dat")

# using nltk stop words and example words for now
STOP_WORDS = set(stopwords.words('english')) | {'test', 'tests', 'main', 'this', 'self'}
2023-10-11 12:35:41 +00:00
def find_all(regex, word):
    """Return the text of every match of ``regex`` in ``word``, lower-cased.

    Matches are collected in the order ``re.finditer`` yields them.
    """
    lowered = []
    for match in re.finditer(regex, word):
        lowered.append(match.group(0).lower())
    return lowered
# https://stackoverflow.com/a/29920015
def camel_case_split(word):
    """Split a camelCase / PascalCase word into its lower-cased parts.

    A split happens between a lower-case and an upper-case letter, and
    inside a run of capitals just before a capital that is followed by a
    lower-case letter (e.g. ``"HTTPServer"`` -> ``["http", "server"]``).
    """
    boundary_pattern = '.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)'
    return find_all(boundary_pattern, word)
def identifier_split(identifier):
    """Split an identifier into lower-cased words.

    The identifier is first cut at underscores, then every piece is split
    further with ``camel_case_split``; the results are returned as one
    flat list.
    """
    words = []
    for chunk in identifier.split("_"):
        words.extend(camel_case_split(chunk))
    return words
def comment_split(comment):
    """Return the lower-cased runs of ASCII letters and digits in ``comment``."""
    alnum_run = '[A-Za-z0-9]+'
    return find_all(alnum_run, comment)
def remove_stopwords(input_bow_list):
    """Return the words of ``input_bow_list`` that are not in ``STOP_WORDS``.

    Order and duplicates of the remaining words are preserved.
    """
    kept = []
    for word in input_bow_list:
        if word in STOP_WORDS:
            continue
        kept.append(word)
    return kept
def get_bow(data, split_f):
    """Turn a raw text value into a stop-word-free bag of words.

    :param data: the text to tokenize; ``None`` or a float NaN is treated
        as "no text"
    :param split_f: callable that splits ``data`` into a list of words
    :return: the words produced by ``split_f`` with stop words removed, or
        an empty list when ``data`` is missing
    """
    if data is None:
        return []
    # isinstance (rather than an exact type comparison) also matches float
    # subclasses such as numpy.float64, so their NaN is not passed on to
    # split_f, which expects text.
    if isinstance(data, float) and np.isnan(data):
        return []
    return remove_stopwords(split_f(data))
2023-10-11 11:59:07 +00:00
2023-10-23 13:42:25 +00:00
def pick_most_similar(corpus, query, dictionary, pick_top=5):
    """Rank the corpus documents by similarity to ``query``.

    :param corpus: the documents to index, passed as-is to
        ``SparseMatrixSimilarity``
    :param query: the query vector looked up in the similarity index
    :param dictionary: its length is used as the ``num_features`` of the
        index
    :param pick_top: maximum number of results to return (defaults to 5)
    :return: up to ``pick_top`` ``(position, score)`` pairs, highest score
        first, where ``position`` is the document's position in ``corpus``
    """
    index = SparseMatrixSimilarity(corpus, num_features=len(dictionary))
    sims = index[query]
    ranked = sorted(enumerate(sims), key=lambda x: x[1], reverse=True)
    return ranked[:pick_top]
2023-10-16 14:36:25 +00:00
2023-10-16 13:10:45 +00:00
2023-10-23 13:42:25 +00:00
def print_results(indexes_scores: list[tuple[int, float]], df):
    """Print a human-readable report of the search results.

    :param indexes_scores: ``(position, score)`` pairs as produced by
        ``search``; ``position`` is the row's position in ``df`` and
        ``score`` is printed as a percentage
    :param df: data frame providing the ``comment``, ``type``, ``name``,
        ``file`` and ``line`` columns of every result
    """
    print("\n===== RESULTS: =====")

    for idx, score in indexes_scores:
        # The indices are positions (enumeration order of the corpus), not
        # index labels, so the lookup must be positional to stay correct
        # when the data frame index is not 0..n-1.
        row = df.iloc[idx]

        comment = row["comment"]
        if isinstance(comment, str):
            # Collapse all whitespace (including newlines) to single spaces.
            comment = re.sub(r'\s+', ' ', comment)
            desc = "Description: {c}\n".format(c=comment)
            # Keep the description line short; the limit includes the prefix.
            if len(desc) > 75:
                desc = desc[:75] + '...\n'
        else:
            # Missing comments are not strings (e.g. NaN): print no description.
            desc = ""

        print("\nSimilarity: {s:2.02f}%".format(s=score * 100))
        print("Python {feat}: {name}\n{desc}File: {file}\nLine: {line}"
              .format(feat=row["type"], name=row["name"], desc=desc,
                      file=row["file"], line=row["line"]))
def build_doc2vec_model(corpus_list):
    """Train a Doc2Vec model on ``corpus_list`` and save it to ``DOC2VEC_MODEL``.

    Every document is tagged with its position in ``corpus_list``.

    :param corpus_list: list of documents, each a list of words
    :return: the trained model
    """
    tagged_docs = []
    for position, words in enumerate(corpus_list):
        tagged_docs.append(TaggedDocument(words, [position]))

    d2v = Doc2Vec(vector_size=100, epochs=100, sample=1e-5)
    d2v.build_vocab(tagged_docs)
    d2v.train(tagged_docs, total_examples=d2v.corpus_count, epochs=d2v.epochs)
    d2v.save(DOC2VEC_MODEL)
    return d2v
2023-10-11 15:49:38 +00:00
2023-10-11 12:35:41 +00:00
2023-10-23 13:42:25 +00:00
def load_data() -> pd.DataFrame:
    """Read the dataset at ``IN_DATASET`` and add bag-of-words columns.

    ``name_bow`` holds the words of each ``name`` (split as an identifier)
    and ``comment_bow`` the words of each ``comment``; both come from
    ``get_bow``.

    :return: the data frame with the two extra columns
    """
    frame = pd.read_csv(IN_DATASET, index_col=0)
    # The splitter is forwarded to get_bow as its second positional argument.
    frame["name_bow"] = frame["name"].apply(get_bow, args=(identifier_split,))
    frame["comment_bow"] = frame["comment"].apply(get_bow, args=(comment_split,))
    return frame
2023-10-11 15:49:38 +00:00
2023-10-23 13:42:25 +00:00
def search(query: str, method: str, df: pd.DataFrame) -> list[tuple[int, float]]:
    """Find the documents of ``df`` most similar to ``query``.

    Each document is the concatenation of a row's ``name_bow`` and
    ``comment_bow`` word lists.

    :param query: free-text query; tokenized like a comment
    :param method: one of ``"tfidf"``, ``"freq"``, ``"lsi"`` or ``"doc2vec"``
    :param df: data frame with ``name_bow`` and ``comment_bow`` columns
    :return: ``(position, score)`` pairs for the best matches
    :raises ValueError: if ``method`` is not one of the supported names
    """
    # Fail fast: reject an unknown method before any corpus work is done.
    if method not in ("tfidf", "freq", "lsi", "doc2vec"):
        raise ValueError("method unknown")

    corpus_list = [
        name_bow + comment_bow
        for name_bow, comment_bow in zip(df["name_bow"], df["comment_bow"])
    ]
    query_w = get_bow(query, comment_split)

    if method == "doc2vec":
        # NOTE: a model already saved at DOC2VEC_MODEL is loaded as-is; it is
        # only trained (and saved) when that file does not exist.
        if os.path.exists(DOC2VEC_MODEL):
            model = Doc2Vec.load(DOC2VEC_MODEL)
        else:
            model = build_doc2vec_model(corpus_list)

        dv_query = model.infer_vector(query_w)
        return model.dv.most_similar([dv_query], topn=5)

    # The remaining methods all work on bag-of-words vectors.
    dictionary = Dictionary(corpus_list)
    corpus_bow = [dictionary.doc2bow(text) for text in corpus_list]
    query_bow = dictionary.doc2bow(query_w)

    if method == "tfidf":
        tfidf = TfidfModel(corpus_bow)
        return pick_most_similar(tfidf[corpus_bow], tfidf[query_bow], dictionary)
    if method == "lsi":
        lsi = LsiModel(corpus_bow)
        return pick_most_similar(lsi[corpus_bow], lsi[query_bow], dictionary)

    # method == "freq": compare the raw bag-of-words vectors directly.
    return pick_most_similar(corpus_bow, query_bow, dictionary)
2023-10-11 11:59:07 +00:00
def main():
    """Command-line entry point: parse arguments, run the search, print results.

    Takes two positional arguments, the similarity ``method`` and the
    ``query`` text. An unsupported method is rejected by the argument parser
    before the dataset is loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("method", help="the method to compare similarities with", type=str,
                        choices=["tfidf", "freq", "lsi", "doc2vec"])
    parser.add_argument("query", help="the query to search the corpus with", type=str)
    args = parser.parse_args()

    df = load_data()
    indexes_scores = search(args.query, args.method, df)
    print_results(indexes_scores, df)
2023-10-11 11:59:07 +00:00
if __name__ == "__main__":
    # coloredlogs.install() attaches its own handler to the root logger, and
    # logging.basicConfig() does nothing once the root logger has a handler,
    # so the format and level are passed to coloredlogs directly.
    coloredlogs.install(level=logging.INFO,
                        fmt='%(asctime)s : %(levelname)s : %(message)s')
    main()