"""Instrument the benchmark sources and expose the instrumented functions.

Every function definition is renamed with ``SUFFIX`` and eligible comparisons
are replaced by calls to ``operators.evaluate_condition``. The instrumented
code is executed inside this module so that ``invoke`` can call it by name.
"""
import ast
import os.path
import sys
from collections import defaultdict
from typing import Any, Optional, Dict, DefaultDict, Sequence, Tuple, List

import astunparse
import frozendict
import pandas as pd
import tqdm

import operators
from operators import evaluate_condition

ROOT_DIR: str = os.path.dirname(__file__)
IN_DIR: str = os.path.join(ROOT_DIR, 'benchmark')
OUT_DIR: str = os.path.join(ROOT_DIR, 'instrumented')
SUFFIX: str = "_instrumented"


class BranchTransformer(ast.NodeTransformer):
    """Rename function definitions and wrap comparisons in ``evaluate_condition``.

    Each instrumented comparison receives a sequential, 1-based branch number.
    Comparisons inside ``assert`` and ``return`` statements, identity and
    membership tests, and chained comparisons are left untouched.
    """

    # Maps instrumented function name -> half-open range [first, last + 1) of
    # the branch numbers assigned inside that function.
    branches_range: Dict[str, Tuple[int, int]]
    # Number of comparisons instrumented so far.
    branch_num: int
    # Original name of the function currently being visited, if any.
    instrumented_name: Optional[str]
    in_assert: bool
    in_return: bool
    # Number of function definitions visited.
    function_nodes: int

    # Operators whose comparisons are never instrumented.
    _SKIPPED_OPS = (ast.Is, ast.IsNot, ast.In, ast.NotIn)

    def __init__(self):
        super().__init__()
        self.branch_num = 0
        self.instrumented_name = None
        self.in_assert = False
        self.in_return = False
        self.branches_range = {}
        self.function_nodes = 0

    @staticmethod
    def to_instrumented_name(name: str):
        """Return ``name`` with the instrumentation suffix appended."""
        return name + SUFFIX

    @staticmethod
    def to_original_name(name: str):
        """Return ``name`` with the instrumentation suffix removed."""
        assert name.endswith(SUFFIX)
        return name[:len(name) - len(SUFFIX)]

    def visit_Assert(self, ast_node):
        """Visit an ``assert`` while flagging that its comparisons are skipped."""
        self.in_assert = True
        self.generic_visit(ast_node)
        self.in_assert = False
        return ast_node

    def visit_Return(self, ast_node):
        """Visit a ``return`` while flagging that its comparisons are skipped."""
        self.in_return = True
        self.generic_visit(ast_node)
        self.in_return = False
        return ast_node

    def visit_FunctionDef(self, ast_node):
        """Rename the function and record the branch numbers used in its body."""
        self.function_nodes += 1
        # Remember the enclosing function so that recursive calls in the outer
        # body are still renamed after a nested definition has been visited.
        enclosing_name = self.instrumented_name
        self.instrumented_name = ast_node.name
        b_start = self.branch_num
        ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name)
        inner_node = self.generic_visit(ast_node)
        self.branches_range[ast_node.name] = (b_start + 1, self.branch_num + 1)
        self.instrumented_name = enclosing_name
        return inner_node

    def visit_Call(self, ast_node):
        """Redirect direct recursive calls to the instrumented function name."""
        # NOTE(review): the call's arguments are not visited, so comparisons
        # and recursive calls nested inside them are left as they are --
        # confirm whether this is intended.
        if isinstance(ast_node.func, ast.Name) and ast_node.func.id == self.instrumented_name:
            ast_node.func.id = BranchTransformer.to_instrumented_name(ast_node.func.id)
        return ast_node

    def visit_Compare(self, ast_node):
        """Replace a simple comparison with a call to ``evaluate_condition``.

        The generated call receives the branch number, the operator class name,
        and the two operands of the comparison.
        """
        skip = (
            self.in_assert
            or self.in_return
            # A chained comparison (a < b < c) cannot be expressed as a single
            # two-operand call without changing its meaning.
            or len(ast_node.ops) != 1
            # ``ops`` holds operator *instances*, so an isinstance test is needed.
            or isinstance(ast_node.ops[0], self._SKIPPED_OPS)
        )
        if skip:
            return self.generic_visit(ast_node)

        self.branch_num += 1
        # NOTE(review): the operands are inserted without being visited, so
        # nested comparisons/calls inside them are not transformed.
        return ast.Call(
            func=ast.Name(id=evaluate_condition.__name__, ctx=ast.Load()),
            args=[
                ast.Constant(value=self.branch_num, kind=None),
                ast.Constant(value=type(ast_node.ops[0]).__name__, kind=None),
                ast_node.left,
                ast_node.comparators[0],
            ],
            keywords=[],
        )


# Type name of an argument annotation; the stored value is None when the
# argument has no annotation.
ArgType = str
Arg = Tuple[str, ArgType]
Params = 'frozendict.frozendict[str, Any]'
SignatureDict = Dict[str, List[Arg]]

# Instrumented function name -> half-open range of its branch numbers.
n_of_branches: Dict[str, Tuple[int, int]] = {}
# Instrumented top-level function name -> list of (argument name, annotation).
functions: SignatureDict = {}
# Instrumented top-level function name -> path components (relative to
# ROOT_DIR, extension removed) of the file that defines it.
module_of: Dict[str, List[str]] = {}


def _annotation_name(annotation: Optional[ast.expr]) -> Optional[str]:
    """Return the textual form of an argument annotation, or None if absent."""
    if annotation is None:
        return None
    if isinstance(annotation, ast.Name):
        return annotation.id
    # Subscripts, attributes, string annotations, ...: fall back to source text.
    return astunparse.unparse(annotation).strip()


def instrument(source_path: str, target_path: str, save_instrumented=True) -> Tuple[int, int]:
    """Instrument one source file and load its functions into this module.

    :param source_path: path of the Python file to instrument
    :param target_path: where the instrumented source is written
    :param save_instrumented: whether to write the instrumented source at all
    :return: (function definitions found, comparisons instrumented)
    """
    with open(source_path, "r", encoding="utf-8") as f:
        source = f.read()

    node = ast.parse(source)

    b = BranchTransformer()
    b.visit(node)
    functions_found = b.function_nodes
    # branch_num starts at 0 and is incremented once per instrumented
    # comparison, so it already is the count.
    conditions_found = b.branch_num

    n_of_branches.update(b.branches_range)

    # Make sure the line numbers are ok before printing
    node = ast.fix_missing_locations(node)

    if save_instrumented:
        os.makedirs(os.path.dirname(target_path) or ".", exist_ok=True)
        with open(target_path, "w", encoding="utf-8") as f:
            print(astunparse.unparse(node), file=f)

    # The instrumented definitions are executed in this module's namespace so
    # that ``invoke`` can look them up by name. Only run this on trusted sources.
    current_module = sys.modules[__name__]
    code = compile(node, filename="", mode="exec")
    exec(code, current_module.__dict__)

    # Figure out the top level function definitions
    assert isinstance(node, ast.Module)
    module_path = os.path.splitext(
        os.path.normpath(os.path.relpath(source_path, ROOT_DIR)))[0].split(os.sep)
    for f_def in node.body:
        if not isinstance(f_def, ast.FunctionDef):
            continue
        functions[f_def.name] = [
            (arg.arg, _annotation_name(arg.annotation)) for arg in f_def.args.args
        ]
        module_of[f_def.name] = module_path

    return functions_found, conditions_found


def invoke(f_name: str, f_args: Params) -> Any:
    """Call a loaded instrumented function with the given keyword arguments.

    The distance dictionaries in ``operators`` are reset before the call.

    :param f_name: instrumented name of the function to call
    :param f_args: mapping from argument name to value
    :return: whatever the instrumented function returns
    :raises ValueError: if the function is not loaded or an argument is missing
    """
    operators.distances_true = {}
    operators.distances_false = {}

    if f_name not in functions:
        raise ValueError(f"Function '{f_name}' not loaded")

    for arg_name, _ in functions[f_name]:
        if arg_name not in f_args:
            raise ValueError(f"Required argument '{arg_name}' not provided")

    return getattr(sys.modules[__name__], f_name)(**f_args)


def find_py_files(search_dir: str):
    """Yield the path of every ``.py`` file found under ``search_dir``."""
    for (cwd, dirs, files) in os.walk(search_dir):
        for file in files:
            if file.endswith(".py"):
                yield os.path.join(cwd, file)


def load_benchmark(save_instrumented=True, files_count: Sequence[str] = ()) -> Tuple[int, int, int]:
    """Instrument the files under ``IN_DIR``.

    :param save_instrumented: whether instrumented sources are written to ``OUT_DIR``
    :param files_count: names of the files to load (directory and extension are
        ignored); when empty every file is loaded
    :return: (files instrumented, function definitions found, comparisons instrumented)
    """
    to_load = {os.path.splitext(os.path.basename(file))[0] + ".py" for file in files_count}
    do_all = not to_load

    loaded_files = 0
    functions_count = 0
    conditions_count = 0
    for file in tqdm.tqdm(find_py_files(IN_DIR), desc="Instrumenting"):
        filename = os.path.basename(file)
        if do_all or filename in to_load:
            fs, cs = instrument(file, os.path.join(OUT_DIR, filename),
                                save_instrumented=save_instrumented)
            loaded_files += 1
            functions_count += fs
            conditions_count += cs
    return loaded_files, functions_count, conditions_count


def call_statement(f_name: str, f_args: Params) -> str:
    """Return the source text of a keyword-argument call to ``f_name``."""
    # repr() quotes strings so the result reads as valid call syntax
    arg_list: List[str] = [f"{k}={repr(v)}" for k, v in f_args.items()]
    return f"{f_name}({', '.join(arg_list)})"


def get_benchmark() -> Dict[str, List[str]]:
    """
    Returns a dictionary associating each source code file name loaded (without extension) with the list of
    (non-instrumented) function names defined within it
    """
    names: DefaultDict[str, List[str]] = defaultdict(list)
    for f in functions:
        names[module_of[f][-1]].append(BranchTransformer.to_original_name(f))
    return dict(names)


def main():
    """Instrument the whole benchmark and print summary counts as a LaTeX table."""
    files_count, functions_count, conditions_count = load_benchmark(save_instrumented=True)
    df = pd.DataFrame.from_records([
        {'Type': 'Python Files', 'Number': files_count},
        {'Type': 'Function Nodes', 'Number': functions_count},
        {'Type': 'Comparison Nodes', 'Number': conditions_count},
    ])
    print(df.to_latex(index=False))


if __name__ == '__main__':
    main()