"""Instrument benchmark source files and load the instrumented functions.

Each benchmark file is parsed, rewritten by :class:`BranchTransformer`
(functions get the ``_instrumented`` suffix and eligible comparisons become
calls to ``evaluate_condition``), optionally written to ``OUT_DIR`` and then
executed inside this module's namespace so the functions can be called via
:func:`invoke`.
"""
import ast
import os.path
import sys
from collections import defaultdict
from typing import Any, Optional

import astunparse
import tqdm
from frozendict import frozendict

from operators import evaluate_condition

ROOT_DIR: str = os.path.dirname(__file__)
IN_DIR: str = os.path.join(ROOT_DIR, 'benchmark')
OUT_DIR: str = os.path.join(ROOT_DIR, 'instrumented')
SUFFIX: str = "_instrumented"


class BranchTransformer(ast.NodeTransformer):
    """Rename functions and replace comparisons with ``evaluate_condition`` calls.

    Every instrumented comparison receives a fresh, 1-based branch number.
    Comparisons inside ``assert`` and ``return`` statements, identity and
    membership tests, and chained comparisons are left untouched.
    """

    # Instrumented function name -> (first branch id, last branch id + 1).
    branches_range: dict[str, tuple[int, int]]
    # Number of comparisons instrumented so far (also the last id handed out).
    branch_num: int
    # Original name of the function currently being visited, if any.
    instrumented_name: Optional[str]
    # True while the children of an ``assert`` statement are being visited.
    in_assert: bool
    # True while the children of a ``return`` statement are being visited.
    in_return: bool

    # Comparison operators that are never turned into evaluate_condition calls.
    _SKIPPED_OPS = (ast.Is, ast.IsNot, ast.In, ast.NotIn)

    def __init__(self):
        self.branch_num = 0
        self.instrumented_name = None
        self.in_assert = False
        self.in_return = False
        self.branches_range = {}

    @staticmethod
    def to_instrumented_name(name: str):
        """Return ``name`` with the instrumentation suffix appended."""
        return name + SUFFIX

    @staticmethod
    def to_original_name(name: str):
        """Return ``name`` with the instrumentation suffix stripped.

        ``name`` must end with ``SUFFIX``; this is checked with an assertion.
        """
        assert name.endswith(SUFFIX)
        return name[:len(name) - len(SUFFIX)]

    def visit_Assert(self, ast_node):
        """Visit an ``assert`` statement with comparison instrumentation off."""
        self.in_assert = True
        self.generic_visit(ast_node)
        self.in_assert = False
        return ast_node

    def visit_Return(self, ast_node):
        """Visit a ``return`` statement with comparison instrumentation off."""
        self.in_return = True
        self.generic_visit(ast_node)
        self.in_return = False
        return ast_node

    def visit_FunctionDef(self, ast_node):
        """Rename the function and record the branch ids assigned in its body."""
        self.instrumented_name = ast_node.name
        b_start = self.branch_num
        ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name)
        inner_node = self.generic_visit(ast_node)
        # Ids handed out while visiting the body are b_start + 1 .. branch_num,
        # stored as a half-open interval.
        self.branches_range[ast_node.name] = (b_start + 1, self.branch_num + 1)
        self.instrumented_name = None
        return inner_node

    def visit_Call(self, ast_node):
        """Redirect direct recursive calls to the renamed function.

        NOTE(review): the call's arguments are not visited, so comparisons and
        recursive calls nested inside them are left as they are. Kept as in
        the original because changing it would alter branch numbering.
        """
        callee = ast_node.func
        if isinstance(callee, ast.Name) and callee.id == self.instrumented_name:
            callee.id = BranchTransformer.to_instrumented_name(callee.id)
        return ast_node

    def visit_Compare(self, ast_node):
        """Replace a simple comparison with a call to ``evaluate_condition``.

        The generated call is
        ``evaluate_condition(<branch id>, "<operator class name>", left, right)``.
        The comparison is only traversed, not replaced, when it is inside an
        ``assert`` or ``return``, when it uses ``is`` / ``is not`` / ``in`` /
        ``not in``, or when it is chained (``a < b < c``).
        """
        first_op = ast_node.ops[0]
        keep_as_is = (
            self.in_assert
            or self.in_return
            # ops holds operator *instances*, so the check must be isinstance;
            # a membership test against the classes never matches.
            or isinstance(first_op, self._SKIPPED_OPS)
            # The replacement call only carries the first operator and the
            # first comparator; rewriting a chain would drop the rest of it.
            or len(ast_node.ops) != 1
        )
        if keep_as_is:
            return self.generic_visit(ast_node)

        self.branch_num += 1
        # NOTE(review): the operands are reused without being visited, so
        # nested comparisons/calls inside them are not rewritten.
        return ast.Call(
            func=ast.Name(evaluate_condition.__name__, ast.Load()),
            args=[
                ast.Constant(value=self.branch_num),
                ast.Constant(value=type(first_op).__name__),
                ast_node.left,
                ast_node.comparators[0],
            ],
            keywords=[],
        )


# Textual form of a parameter annotation (None is stored when there is none).
ArgType = str
Arg = tuple[str, ArgType]
Params = frozendict[str, Any]
SignatureDict = dict[str, list[Arg]]

# Instrumented function name -> (first branch id, last branch id + 1).
n_of_branches: dict[str, tuple[int, int]] = {}
# Instrumented top-level function name -> list of (parameter name, annotation).
functions: SignatureDict = {}
# Instrumented top-level function name -> components of the source file path
# relative to ROOT_DIR, without the extension.
module_of: dict[str, list[str]] = {}


def instrument(source_path: str, target_path: str, save_instrumented=True):
    """Instrument one source file and load its functions into this module.

    :param source_path: path of the Python file to instrument
    :param target_path: where the instrumented source is written
    :param save_instrumented: when false, nothing is written to ``target_path``
    :raises ValueError: if a top-level function of the file has the same
        instrumented name as one that is already registered
    """
    with open(source_path, "r", encoding="utf-8") as src_file:
        tree = ast.parse(src_file.read())

    transformer = BranchTransformer()
    transformer.visit(tree)
    n_of_branches.update(transformer.branches_range)

    # Make sure the line numbers are ok before printing
    tree = ast.fix_missing_locations(tree)

    if save_instrumented:
        out_dir = os.path.dirname(target_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(target_path, "w", encoding="utf-8") as out_file:
            print(astunparse.unparse(tree), file=out_file)

    # Execute the instrumented code in this module's namespace so that the
    # functions become attributes of the module (see invoke).
    current_module = sys.modules[__name__]
    code = compile(tree, filename="", mode="exec")
    exec(code, current_module.__dict__)

    # Figure out the top level function definitions
    assert isinstance(tree, ast.Module)
    top_level_defs: list[ast.FunctionDef] = [
        stmt for stmt in tree.body if isinstance(stmt, ast.FunctionDef)
    ]

    # Path components of the source file, relative to ROOT_DIR and without
    # extension. normpath yields OS-specific separators, hence os.sep.
    relative_stem = os.path.splitext(
        os.path.normpath(os.path.relpath(source_path, ROOT_DIR))
    )[0]
    module_parts = relative_stem.split(os.sep)

    for func_def in top_level_defs:
        if func_def.name in functions:
            raise ValueError(f"Function '{func_def.name}' already loaded from another file")

        # Store the annotation's source text, or None when there is none.
        # ast.unparse also copes with annotations that are not plain names.
        functions[func_def.name] = [
            (arg.arg, None if arg.annotation is None else ast.unparse(arg.annotation))
            for arg in func_def.args.args
        ]
        module_of[func_def.name] = module_parts


def invoke(f_name: str, f_args: Params) -> Any:
    """Call a loaded instrumented function with keyword arguments.

    :param f_name: instrumented name of the function to call
    :param f_args: mapping from parameter name to argument value
    :return: whatever the called function returns
    :raises ValueError: if the function is not loaded or a parameter recorded
        in its signature is missing from ``f_args``
    """
    if f_name not in functions:
        raise ValueError(f"Function '{f_name}' not loaded")

    for arg_name, _ in functions[f_name]:
        if arg_name not in f_args:
            raise ValueError(f"Required argument '{arg_name}' not provided")

    return getattr(sys.modules[__name__], f_name)(**f_args)


def find_py_files(search_dir: str):
    """Yield the path of every ``.py`` file found under ``search_dir``."""
    for (cwd, dirs, files) in os.walk(search_dir):
        for file in files:
            if file.endswith(".py"):
                yield os.path.join(cwd, file)


def load_benchmark(save_instrumented=True, files: list[str] = ()):
    """Instrument the files under ``IN_DIR``, writing results to ``OUT_DIR``.

    :param save_instrumented: forwarded to :func:`instrument`
    :param files: names (directory and extension are ignored) of the files to
        load; when empty, every ``.py`` file under ``IN_DIR`` is loaded
    """
    to_load = {os.path.splitext(os.path.basename(file))[0] + ".py" for file in files}
    do_all = not to_load

    for file in tqdm.tqdm(find_py_files(IN_DIR), desc="Instrumenting"):
        filename = os.path.basename(file)
        if do_all or filename in to_load:
            instrument(file, os.path.join(OUT_DIR, filename), save_instrumented=save_instrumented)


def call_statement(f_name: str, f_args: Params) -> str:
    """Render a call of ``f_name`` with ``f_args`` as keyword-argument source text."""
    # repr() quotes strings
    arg_list = [f"{k}={repr(v)}" for k, v in f_args.items()]
    return f"{f_name}({', '.join(arg_list)})"


def get_benchmark() -> dict[str, list[str]]:
    """
    Returns a dictionary associating each loaded source code file name (without extension)
    with the list of (non-instrumented) function names defined within it
    """
    names: defaultdict[str, list[str]] = defaultdict(list)
    for f in functions:
        names[module_of[f][-1]].append(BranchTransformer.to_original_name(f))
    return dict(names)


if __name__ == '__main__':
    load_benchmark(save_instrumented=True)