import ast
import os.path
import random
import sys
from typing import Any, Optional

import astunparse
import tqdm
from frozendict import frozendict

from operators import compute_distances

# Paths and naming conventions used by the instrumentation pipeline.
ROOT_DIR: str = os.path.dirname(__file__)
IN_DIR: str = os.path.join(ROOT_DIR, 'benchmark')
OUT_DIR: str = os.path.join(ROOT_DIR, 'instrumented')
SUFFIX: str = "_instrumented"

# Smallest distance observed so far for each condition number, towards making
# the condition true / false respectively (maintained by ``update_maps``).
distances_true: dict[int, int] = {}
distances_false: dict[int, int] = {}

# Archive of solutions (presumably keyed by branch number; populated elsewhere).
archive_true_branches: dict[int, Any] = {}
archive_false_branches: dict[int, Any] = {}


class BranchTransformer(ast.NodeTransformer):
    """Rewrite a module so that its comparisons report branch distances.

    Every function definition ``def f`` is renamed to ``f`` + ``SUFFIX``, and
    direct calls to the function currently being visited (recursive calls) are
    renamed the same way.  Every single-operator comparison ``lhs <op> rhs``
    that is not inside an ``assert`` or ``return`` statement is replaced by
    ``evaluate_condition(num, "<OpClassName>", lhs, rhs)``, where ``num`` is a
    branch number unique within this transformer instance (starting at 1).
    """

    # Instrumented function name -> (counter before, counter after) visiting it;
    # the function owns branch numbers ``before + 1 .. after`` (inclusive).
    branches_range: dict[str, tuple[int, int]]
    # Last branch number handed out.
    branch_num: int
    # Original (un-suffixed) name of the function currently being visited.
    instrumented_name: Optional[str]
    # True while visiting the inside of an ``assert`` statement.
    in_assert: bool
    # True while visiting the inside of a ``return`` statement.
    in_return: bool

    # Comparison operators that are never instrumented.  ``in`` is deliberately
    # absent: ``evaluate_condition`` has a dedicated membership distance for it.
    _SKIPPED_OPS = (ast.Is, ast.IsNot, ast.NotIn)

    def __init__(self):
        self.branch_num = 0
        self.instrumented_name = None
        self.in_assert = False
        self.in_return = False
        self.branches_range = {}

    @staticmethod
    def to_instrumented_name(name: str):
        """Return ``name`` with the instrumentation suffix appended."""
        return name + SUFFIX

    @staticmethod
    def to_original_name(name: str):
        """Strip the instrumentation suffix from ``name`` (which must carry it)."""
        assert name.endswith(SUFFIX)
        return name[:len(name) - len(SUFFIX)]

    def visit_Assert(self, ast_node):
        """Visit an ``assert`` without instrumenting the comparisons inside it."""
        previous, self.in_assert = self.in_assert, True
        try:
            self.generic_visit(ast_node)
        finally:
            self.in_assert = previous
        return ast_node

    def visit_Return(self, ast_node):
        """Visit a ``return`` without instrumenting the comparisons inside it."""
        previous, self.in_return = self.in_return, True
        try:
            self.generic_visit(ast_node)
        finally:
            self.in_return = previous
        return ast_node

    def visit_FunctionDef(self, ast_node):
        """Rename the function and record the branch numbers it contains."""
        enclosing_name = self.instrumented_name
        self.instrumented_name = ast_node.name
        b_start = self.branch_num
        ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name)
        inner_node = self.generic_visit(ast_node)
        self.branches_range[ast_node.name] = (b_start, self.branch_num)
        # Restore (rather than clear) the enclosing function's name, so that
        # recursive calls appearing after a nested ``def`` are still renamed.
        self.instrumented_name = enclosing_name
        return inner_node

    def visit_Call(self, ast_node):
        """Redirect recursive calls to the instrumented version of the function."""
        if isinstance(ast_node.func, ast.Name) and ast_node.func.id == self.instrumented_name:
            ast_node.func.id = BranchTransformer.to_instrumented_name(ast_node.func.id)
        # Arguments may themselves contain recursive calls (``f(f(n - 1))``)
        # or comparisons, so they must be visited too.
        return self.generic_visit(ast_node)

    def visit_Compare(self, ast_node):
        """Replace ``lhs <op> rhs`` with an ``evaluate_condition`` call.

        Left unchanged (though their operands are still visited): comparisons
        inside ``assert``/``return``, chained comparisons such as
        ``a < b < c`` (rewriting only the first pair would drop the others),
        and comparisons whose operator is in ``_SKIPPED_OPS``.
        """
        # Visit operands first so nested comparisons and recursive calls
        # inside them are instrumented / renamed as well.
        self.generic_visit(ast_node)
        if (len(ast_node.ops) != 1
                or isinstance(ast_node.ops[0], self._SKIPPED_OPS)
                or self.in_assert
                or self.in_return):
            return ast_node
        self.branch_num += 1
        call = ast.Call(
            func=ast.Name("evaluate_condition", ast.Load()),
            args=[
                ast.Constant(self.branch_num),
                ast.Constant(ast_node.ops[0].__class__.__name__),
                ast_node.left,
                ast_node.comparators[0],
            ],
            keywords=[],
        )
        # Keep the original source position for meaningful tracebacks.
        return ast.copy_location(call, ast_node)


def update_maps(condition_num, d_true, d_false):
    """Keep, per condition, the smallest true/false distance observed so far."""
    distances_true[condition_num] = min(distances_true.get(condition_num, d_true), d_true)
    distances_false[condition_num] = min(distances_false.get(condition_num, d_false), d_false)


def _membership_distances(lhs, rhs):
    """Return ``(distance_true, distance_false)`` for the test ``lhs in rhs``.

    The true-distance is the smallest ``abs(lhs - elem)`` over the elements of
    ``rhs`` (iterating a dict yields its keys), with single characters mapped
    through ``ord``; it is ``sys.maxsize`` for an empty ``rhs``.  Operands
    that have no numeric distance (e.g. substrings, tuples) fall back to the
    exact outcome of the membership test (0 if contained, else 1).
    """
    def to_number(value):
        return ord(value) if isinstance(value, str) else value

    try:
        target = to_number(lhs)
        minimum = min((abs(target - to_number(elem)) for elem in rhs), default=sys.maxsize)
    except TypeError:
        minimum = 0 if lhs in rhs else 1
    return minimum, (1 if minimum == 0 else 0)


def evaluate_condition(num, op, lhs, rhs):  # type: ignore
    """Record the branch distances of condition ``num`` and return its outcome.

    Called by instrumented code in place of ``lhs <op> rhs``.

    :param num: branch number assigned by ``BranchTransformer``.
    :param op: name of the ``ast.cmpop`` class, e.g. ``"Lt"`` or ``"In"``.
    :param lhs: value of the left operand.
    :param rhs: value of the right operand.
    :return: ``distance_true == 0``, which is relied on to equal the outcome
        of ``lhs <op> rhs`` (for non-``In`` operators this assumes
        ``compute_distances`` guarantees it — confirm in ``operators``).
    """
    if op == "In":
        distance_true, distance_false = _membership_distances(lhs, rhs)
    else:
        distance_true, distance_false = compute_distances(op, lhs, rhs)
    update_maps(num, distance_true, distance_false)
    # distance == 0 equivalent to actual test by construction
    return distance_true == 0


ArgType = str
# (argument name, annotation source text or None when unannotated)
Arg = tuple[str, Optional[ArgType]]
Params = frozendict[str, Any]
SignatureDict = dict[str, list[Arg]]

# Instrumented function name -> branch range (see ``BranchTransformer``).
n_of_branches: dict[str, tuple[int, int]] = {}
# Instrumented function name -> its positional parameters.
functions: SignatureDict = {}
# Instrumented function name -> dotted module path of its source file.
module_of: dict[str, str] = {}


def _module_name(source_path: str) -> str:
    """Return the dotted module path of ``source_path`` relative to ``ROOT_DIR``."""
    relative = os.path.normpath(os.path.relpath(source_path, ROOT_DIR))
    stem, _ = os.path.splitext(relative)
    return stem.replace(os.sep, ".")


def _annotation_name(annotation: Optional[ast.expr]) -> Optional[str]:
    """Return the source text of a parameter annotation, or None if absent."""
    if annotation is None:
        return None
    if isinstance(annotation, ast.Name):
        return annotation.id
    if isinstance(annotation, ast.Constant) and isinstance(annotation.value, str):
        return annotation.value  # string ("forward reference") annotation
    return ast.unparse(annotation)  # e.g. ``list[int]`` or ``Optional[str]``


def instrument(source_path: str, target_path: str, save_instrumented=True):
    """Instrument a Python file, load it into this module and register its functions.

    The instrumented code is executed in this module's namespace, so its
    functions become attributes of this module (see ``invoke``).  Only run
    this on trusted files: their code is executed as-is.

    :param source_path: file to instrument.
    :param target_path: where the instrumented source is written.
    :param save_instrumented: whether to write the instrumented source at all.
    :raises ValueError: if a top-level function of the file is already loaded
        (from another file, or defined twice in this one).
    """
    with open(source_path, "r", encoding="utf-8") as src:
        source = src.read()
    node = ast.parse(source)
    b = BranchTransformer()
    b.visit(node)
    node = ast.fix_missing_locations(node)  # Make sure the line numbers are ok before printing

    # Figure out the top level function definitions (names already instrumented).
    assert isinstance(node, ast.Module)
    top_level_f_ast: list[ast.FunctionDef] = [f for f in node.body if isinstance(f, ast.FunctionDef)]

    # Reject duplicates before exec, so an already-loaded function is not
    # silently overwritten in this module's namespace.
    seen = set(functions)
    for f in top_level_f_ast:
        if f.name in seen:
            raise ValueError(f"Function '{f.name}' already loaded from another file")
        seen.add(f.name)

    n_of_branches.update(b.branches_range)

    if save_instrumented:
        target_dir = os.path.dirname(target_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        with open(target_path, "w", encoding="utf-8") as out:
            print(astunparse.unparse(node), file=out)

    current_module = sys.modules[__name__]
    # Node positions refer to the original source, so name it in tracebacks.
    code = compile(node, filename=source_path, mode="exec")
    exec(code, current_module.__dict__)

    module_name = _module_name(source_path)
    for f in top_level_f_ast:
        functions[f.name] = [(arg.arg, _annotation_name(arg.annotation)) for arg in f.args.args]
        module_of[f.name] = module_name


def invoke(f_name: str, f_args: Params) -> Any:
    """Call the loaded (instrumented) function ``f_name`` with keyword arguments ``f_args``.

    :raises ValueError: if ``f_name`` is not loaded, or one of its positional
        parameters is missing from ``f_args``.
    """
    if f_name not in functions:
        raise ValueError(f"Function '{f_name}' not loaded")
    for arg_name, _ in functions[f_name]:
        if arg_name not in f_args:
            raise ValueError(f"Required argument '{arg_name}' not provided")
    return getattr(sys.modules[__name__], f_name)(**f_args)


def find_py_files(search_dir: str):
    """Yield the paths of all ``.py`` files under ``search_dir``, recursively."""
    for cwd, dirs, files in os.walk(search_dir):
        dirs.sort()  # deterministic traversal => reproducible loading order
        for file in sorted(files):
            if file.endswith(".py"):
                yield os.path.join(cwd, file)


def load_benchmark(save_instrumented=True):
    """Instrument and load every ``.py`` file under ``IN_DIR``.

    NOTE: outputs are written to ``OUT_DIR/<basename>``, so files sharing a
    basename in different subdirectories overwrite each other's output.
    """
    for file in tqdm.tqdm(find_py_files(IN_DIR), desc="Instrumenting"):
        instrument(file, os.path.join(OUT_DIR, os.path.basename(file)), save_instrumented=save_instrumented)


def call_statement(f_name: str, f_args: Params) -> str:
    """Render a call to ``f_name`` as Python source, e.g. ``f(x=1, s='a')``."""
    arg_list = ", ".join(f"{k}={v!r}" for k, v in f_args.items())  # repr quotes strings
    return f"{f_name}({arg_list})"


if __name__ == '__main__':
    load_benchmark(save_instrumented=True)