# kse-01/search-data.py
#
# 140 lines
# 4.5 KiB
# Python

import re
import argparse
import os
import pandas as pd
import nltk
import numpy as np
from nltk.corpus import stopwords
from gensim.similarities import SparseMatrixSimilarity, MatrixSimilarity
from gensim.models import TfidfModel, LsiModel, LdaModel
from gensim.models.doc2vec import TaggedDocument, Doc2Vec
from gensim.corpora import Dictionary
from collections import defaultdict
import coloredlogs
import logging
# coloredlogs.install() attaches a handler to the root logger, which turns any
# later logging.basicConfig() call into a silent no-op. Hand the intended level
# and format straight to coloredlogs so they actually take effect.
coloredlogs.install(level=logging.INFO, fmt='%(asctime)s : %(levelname)s : %(message)s')

# All data files live next to this script, independent of the working directory.
SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
IN_DATASET = os.path.join(SCRIPT_DIR, "data.csv")
DOC2VEC_MODEL = os.path.join(SCRIPT_DIR, "doc2vec_model.dat")


def _english_stop_words():
    """Return NLTK's English stop-word list, downloading the corpus only if it is missing."""
    try:
        return stopwords.words('english')
    except LookupError:
        # Corpus not installed yet: fetch it once instead of contacting the
        # NLTK server on every start-up.
        nltk.download('stopwords')
        return stopwords.words('english')


# NLTK English stop words plus a few words that are very common in code
# identifiers and would otherwise dominate the matches.
STOP_WORDS = set(_english_stop_words()).union(['test', 'tests', 'main', 'this', 'self'])
def find_all(regex, word):
    """Return every non-overlapping match of *regex* in *word*, lowercased.

    Matches are returned in the order they occur in *word*; an input with no
    matches yields an empty list.
    """
    lowered = []
    for match in re.finditer(regex, word):
        lowered.append(match.group(0).lower())
    return lowered
# https://stackoverflow.com/a/29920015
def camel_case_split(word):
    """Split a camelCase / PascalCase word into lowercase parts.

    A boundary is placed between a lowercase letter and a following uppercase
    letter, and before the last capital of an uppercase run that is followed
    by a lowercase letter (so "HTTPServer" -> ["http", "server"]). An empty
    string yields no parts.
    """
    boundary_pattern = '.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)'
    return [piece.group(0).lower() for piece in re.finditer(boundary_pattern, word)]
def identifier_split(identifier):
    """Split a Python identifier into lowercase words.

    The identifier is first cut on underscores (snake_case), and each piece is
    then split on camelCase boundaries by camel_case_split. Empty pieces, such
    as those produced by leading, trailing or doubled underscores, contribute
    no words.
    """
    words = []
    for chunk in identifier.split("_"):
        words.extend(camel_case_split(chunk))
    return words
def comment_split(comment):
    """Tokenise free text into lowercase alphanumeric words.

    Every maximal run of ASCII letters and digits becomes one token;
    punctuation, underscores and whitespace act as separators.
    """
    raw_tokens = re.findall('[A-Za-z0-9]+', comment)
    return [token.lower() for token in raw_tokens]
def remove_stopwords(input_bow_list):
    """Return the tokens of *input_bow_list* that are not in STOP_WORDS.

    Order and duplicates of the surviving tokens are preserved.
    """
    return list(filter(lambda token: token not in STOP_WORDS, input_bow_list))
def get_bow(data, split_f):
    """Turn one raw text value into a stop-word-free list of tokens.

    Args:
        data: the text to tokenise. A missing value, i.e. ``None`` or a float
            NaN (for example an empty CSV cell read by pandas), yields no tokens.
        split_f: tokeniser that maps a string to a list of tokens, e.g.
            identifier_split or comment_split.

    Returns:
        The tokens produced by ``split_f`` with STOP_WORDS removed.
    """
    # isinstance rather than type(...) == float, so that float subclasses such
    # as numpy's np.float64('nan') are recognised as missing too instead of
    # being handed to the string tokeniser.
    if data is None or (isinstance(data, float) and np.isnan(data)):
        return []
    return remove_stopwords(split_f(data))
def print_sims(corpus, query, df, dictionary):
    """Score *query* against every document in *corpus* and print the 5 best.

    *corpus* and *query* must be vectors in the same space (raw bag-of-words,
    TF-IDF or LSI, as built by search). The sparse similarity index is sized
    with ``len(dictionary)`` features. Documents are ranked by descending
    score, and ties keep their corpus order.
    """
    index = SparseMatrixSimilarity(corpus, num_features=len(dictionary))
    # (document position, similarity) for every document in the corpus
    scored = list(enumerate(index[query]))
    scored.sort(key=lambda entry: entry[1], reverse=True)
    print_results(scored[:5], df)
def print_results(idxs_scores, df):
    """Print a human-readable report for each (row, similarity) search hit.

    Args:
        idxs_scores: iterable of ``(idx, score)`` pairs. ``idx`` is looked up
            with ``df.loc`` and ``score`` is printed as a percentage
            (``score * 100``).
        df: DataFrame with "comment", "type", "name", "file" and "line"
            columns.
    """
    print("\n===== RESULTS: =====")
    for idx, score in idxs_scores:
        row = df.loc[idx]
        comment = row["comment"]
        if isinstance(comment, str):
            # Collapse every whitespace run (newlines included) into one space
            # so the description prints on a single line.
            comment = re.sub(r'\s+', ' ', comment)
            desc = "Description: {c}\n".format(c=comment)
            # Cap the description at 75 characters and mark the cut.
            if len(desc) > 75:
                desc = desc[:75] + '...\n'
        else:
            # Missing comment (e.g. NaN from an empty CSV cell): omit the line.
            desc = ""
        print("\nSimilarity: {s:2.02f}%".format(s=score*100))
        print("Python {feat}: {name}\n{desc}File: {file}\nLine: {line}" \
            .format(feat=row["type"], name=row["name"], desc=desc, file=row["file"], line=row["line"]))
def build_doc2vec_model(corpus_list):
    """Train a Doc2Vec model on *corpus_list*, save it to DOC2VEC_MODEL, and return it.

    Each token list is tagged with its position in *corpus_list*, so the
    document tags in the trained model refer back to corpus (row) positions.
    Hyper-parameters: 100-dimensional vectors, 100 epochs, sample=1e-5.
    """
    tagged_docs = []
    for position, tokens in enumerate(corpus_list):
        tagged_docs.append(TaggedDocument(tokens, [position]))

    model = Doc2Vec(vector_size=100, epochs=100, sample=1e-5)
    model.build_vocab(tagged_docs)
    model.train(tagged_docs, total_examples=model.corpus_count, epochs=model.epochs)

    # Persist so that later searches can reuse the model instead of retraining.
    model.save(DOC2VEC_MODEL)
    return model
def search(query, method):
    """Rank every row of the dataset against *query* and print the 5 best matches.

    Each CSV row becomes one document: the tokens of its ``name`` identifier
    followed by the tokens of its ``comment``. The query is tokenised like a
    comment.

    Args:
        query: free-text search string.
        method: "freq" (raw bag-of-words counts), "tfidf", "lsi", or "doc2vec".

    Raises:
        ValueError: if *method* is not one of the names above.
    """
    known_methods = ("freq", "tfidf", "lsi", "doc2vec")
    # Validate before doing any expensive work. The previous code raised an
    # undefined name (``Error``), which surfaced as a NameError instead.
    if method not in known_methods:
        raise ValueError("method unknown: {m!r} (expected one of: {k})"
                         .format(m=method, k=", ".join(known_methods)))

    df = pd.read_csv(IN_DATASET)
    df["name_bow"] = df["name"].apply(lambda n: get_bow(n, identifier_split))
    df["comment_bow"] = df["comment"].apply(lambda c: get_bow(c, comment_split))
    # One document per row, in row order: identifier tokens then comment tokens.
    corpus_list = [name_bow + comment_bow
                   for name_bow, comment_bow in zip(df["name_bow"], df["comment_bow"])]
    query_w = get_bow(query, comment_split)

    if method == "doc2vec":
        # NOTE: a model already saved at DOC2VEC_MODEL is reused without checking
        # that it matches the current data.csv; delete the file to retrain.
        if os.path.exists(DOC2VEC_MODEL):
            model = Doc2Vec.load(DOC2VEC_MODEL)
        else:
            model = build_doc2vec_model(corpus_list)
        dvquery = model.infer_vector(query_w)
        print_results(model.dv.most_similar([dvquery], topn=5), df)
        return

    # The remaining methods all start from a shared bag-of-words representation.
    dictionary = Dictionary(corpus_list)
    corpus_bow = [dictionary.doc2bow(text) for text in corpus_list]
    query_bow = dictionary.doc2bow(query_w)
    if method == "tfidf":
        tfidf = TfidfModel(corpus_bow)
        print_sims(tfidf[corpus_bow], tfidf[query_bow], df, dictionary)
    elif method == "freq":
        print_sims(corpus_bow, query_bow, df, dictionary)
    else:  # "lsi"
        lsi = LsiModel(corpus_bow)
        print_sims(lsi[corpus_bow], lsi[query_bow], df, dictionary)
def main():
    """Command-line entry point: parse ``method`` and ``query``, then run search().

    ``method`` is restricted to the names search() supports, so a typo gives
    an argparse usage error (exit status 2) instead of a traceback.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("method", help="the method to compare similarities with", type=str,
                        choices=("freq", "tfidf", "lsi", "doc2vec"))
    parser.add_argument("query", help="the query to search the corpus with", type=str)
    args = parser.parse_args()
    search(args.query, args.method)


if __name__ == "__main__":
    main()