#!/usr/bin/env python3
import javalang
import os
import pandas as pd
import glob
import warnings


# Absolute, symlink-resolved directory containing this script; all other
# paths below are anchored here.
DIR: str = os.path.dirname(os.path.realpath(__file__))
# Root of the Xerces2 Java sources that are parsed.
SOURCES: str = f'{DIR}/xerces2/src'
# Destination directory for the generated feature-vector CSV files.
OUT_DIR: str = f'{DIR}/feature_vectors'
# Directory holding the god_classes.csv input list.
IN_DIR: str = f'{DIR}/god_classes'


def clean_output():
    """Delete every file matching ``*.csv`` directly inside ``OUT_DIR``.

    The match is non-recursive; files with other extensions are left alone.
    """
    for stale_csv in glob.glob(OUT_DIR + '/*.csv'):
        os.remove(stale_csv)


def get_fields(java_class: javalang.tree.ClassDeclaration) -> set[str]:
    """Return the names of all fields declared in ``java_class``.

    A single Java field declaration can introduce several variables
    (``int a, b;``); the name of every declarator is collected, not just
    the first one.
    """
    names = set()
    for field in java_class.fields:
        # One FieldDeclaration may carry multiple VariableDeclarators.
        for declarator in field.declarators:
            names.add(declarator.name)
    return names


def get_methods(java_class: javalang.tree.ClassDeclaration) -> set[str]:
    """Return the set of method names declared in ``java_class``.

    Overloaded methods share a name and therefore appear only once.
    """
    return {method.name for method in java_class.methods}


def get_fields_accessed_by_method(method: javalang.tree.MethodDeclaration) -> set[str]:
    """Return the names referenced through ``MemberReference`` nodes in ``method``.

    For an unqualified reference (``x``) the member name is recorded. For a
    qualified one (``a.x``) the qualifier ``a`` is recorded instead of ``x``.
    The result is not filtered against the class's declared fields, so
    callers should intersect it with the field set if only fields matter.
    """
    referenced = set()
    for _, ref in method.filter(javalang.tree.MemberReference):
        # A missing (None) or empty qualifier means a bare name: keep the
        # member. Otherwise keep the qualifier (a), not the member (x).
        referenced.add(ref.qualifier if ref.qualifier else ref.member)
    return referenced


def get_methods_accessed_by_method(method: javalang.tree.MethodDeclaration, methods: set[str]) -> set[str]:
    """Return the names from ``methods`` that ``method`` invokes without a qualifier.

    Only ``MethodInvocation`` nodes whose qualifier is ``None`` or empty are
    considered, and only when the invoked name is one of ``methods``.

    Args:
        method: the method body to scan.
        methods: names of the methods declared in the enclosing class.
    """
    # Both conditions must hold. Previously `a or b and c` parsed as
    # `a or (b and c)`, and the membership test used the AST node instead
    # of its name, so unqualified calls to class methods were never recorded.
    return {
        node.member
        for _, node in method.filter(javalang.tree.MethodInvocation)
        if not node.qualifier and node.member in methods
    }


def _member_access_matrix(java_class: javalang.tree.ClassDeclaration) -> pd.DataFrame:
    """Build the 0/1 method-by-member access matrix for ``java_class``.

    Rows are method names, in the order each name is first declared. Columns
    are the sorted names of the class's fields and methods. A cell is 1 when
    any method with that row's name references the column's field or calls
    the column's method.
    """
    fields = get_fields(java_class)
    methods = get_methods(java_class)
    columns = sorted(fields | methods)

    # Overloads share one row keyed by name, so union their accesses instead
    # of letting a later overload reset the row written by an earlier one.
    accessed: dict[str, set[str]] = {}
    for m in java_class.methods:
        members = accessed.setdefault(m.name, set())
        members |= get_fields_accessed_by_method(m)
        members |= get_methods_accessed_by_method(m, methods)

    # Only class members are columns. Other names returned by the helpers
    # (locals, parameters, qualifiers) are ignored rather than being added
    # as spurious extra columns.
    rows = [[int(col in members) for col in columns] for members in accessed.values()]
    return pd.DataFrame(rows, index=list(accessed), columns=columns, dtype=int)


def parse(path: str):
    """Write the member-access feature vectors of the class defined in ``path``.

    Parses the Java file and picks the class whose name matches the file
    name, skipping inner classes. It then writes that class's access matrix
    to ``OUT_DIR/<fully qualified name>.csv``. Nothing is written if no
    class name matches the file name.
    """
    with open(path) as file:
        tree = javalang.parse.parse(file.read())

    # Use the first package declaration; if there is none, the class lives
    # in the default (unnamed) package.
    package_name = ''
    for _, node in tree.filter(javalang.tree.PackageDeclaration):
        package_name = node.name
        break

    # Compare the whole file name, not a suffix: `XMLParser.java` must not
    # match a class named `Parser`.
    file_name = os.path.basename(path)
    for _, node in tree.filter(javalang.tree.ClassDeclaration):
        if node.name + '.java' != file_name:
            continue
        # Default-package classes get no leading '.' in their output name.
        fqdn = package_name + '.' + node.name if package_name else node.name
        _member_access_matrix(node).to_csv(OUT_DIR + '/' + fqdn + '.csv')
        break


def main():
    """Regenerate one feature-vector CSV per class listed in ``IN_DIR/god_classes.csv``.

    Creates ``OUT_DIR`` if it does not exist, so ``to_csv`` does not fail on
    a fresh checkout. Then deletes the CSVs left there by a previous run and
    parses each listed class's source file under ``SOURCES``.
    """
    warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)
    os.makedirs(OUT_DIR, exist_ok=True)
    clean_output()

    god_classes = pd.read_csv(IN_DIR + '/god_classes.csv')

    for clazz in god_classes['class_name'].to_list():
        # Map the fully qualified name a.b.C to SOURCES/a/b/C.java.
        clazz_path = SOURCES + '/' + clazz.replace('.', '/') + '.java'
        print(clazz_path)
        parse(clazz_path)


if __name__ == '__main__':
    # Run the extraction only when executed as a script, not when imported.
    main()