"""Instrument benchmark functions with branch-distance probes.

Every comparison in a benchmark function is rewritten into a call to
``evaluate_condition``, which records how far the operands are from making
the comparison true/false.  A DEAP genetic algorithm (``generate``) can then
use those distances as a fitness signal.
"""
import ast
import os.path
import random
import sys
from typing import Any, Optional

import astunparse
import tqdm
from deap import algorithms, base, creator, tools
from frozendict import frozendict

from operators import compute_distances

# hyperparameters
NPOP = 300
NGEN = 200
INDMUPROB = 0.05
MUPROB = 0.1
CXPROB = 0.5
TOURNSIZE = 3
LOW = -1000
UP = 1000
REPS = 10
MAX_STRING_LENGTH = 10

ROOT_DIR: str = os.path.dirname(__file__)
IN_DIR: str = os.path.join(ROOT_DIR, 'benchmark')
OUT_DIR: str = os.path.join(ROOT_DIR, 'instrumented')
SUFFIX: str = "_instrumented"

# Smallest distances observed per branch number during the current execution.
distances_true: dict[int, int] = {}
distances_false: dict[int, int] = {}
branches: list[int] = [1, 2, 3, 4, 5]
# Branch number -> first input that covered the true / false outcome.
archive_true_branches: dict[int, str] = {}
archive_false_branches: dict[int, str] = {}


class BranchTransformer(ast.NodeTransformer):
    """AST transformer that renames functions and instruments comparisons.

    Each top-level visit renames ``def f`` to ``def f_instrumented``, rewrites
    direct recursive calls accordingly and replaces single-operator
    comparisons by ``evaluate_condition(<branch number>, <op name>, lhs, rhs)``.
    """

    # Comparison operators that are left untouched.  ``In`` is deliberately
    # NOT listed because ``evaluate_condition`` has a dedicated "In" branch.
    # NOTE(review): assumes ``compute_distances`` has no support for
    # Is / IsNot / NotIn -- confirm against the ``operators`` module.
    _SKIPPED_OPS = (ast.Is, ast.IsNot, ast.NotIn)

    branch_num: int
    instrumented_name: Optional[str]

    def __init__(self):
        self.branch_num = 0
        self.instrumented_name = None

    @staticmethod
    def to_instrumented_name(name: str) -> str:
        """Return *name* with the instrumentation suffix appended."""
        return name + SUFFIX

    @staticmethod
    def to_original_name(name: str) -> str:
        """Strip the instrumentation suffix from *name* (which must carry it)."""
        assert name.endswith(SUFFIX)
        return name[:len(name) - len(SUFFIX)]

    def visit_Assert(self, ast_node):
        # Disable recursion in asserts, i.e. do not instrument assert conditions
        # TODO: may fail if assertion calls method (which must be renamed)
        return ast_node

    def visit_Return(self, ast_node):
        # Same thing for return statements
        # TODO: may fail if return statement calls method (which must be renamed)
        return ast_node

    def visit_FunctionDef(self, ast_node):
        """Rename the function and instrument its body."""
        # Remember the enclosing function so that its name is restored once a
        # nested definition has been processed (instead of being lost).
        enclosing_name = self.instrumented_name
        self.instrumented_name = ast_node.name
        ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name)
        inner_node = self.generic_visit(ast_node)
        self.instrumented_name = enclosing_name
        return inner_node

    def visit_Call(self, ast_node):
        """Redirect direct recursive calls to the instrumented function."""
        # Children are intentionally not visited, so arguments stay as written.
        if isinstance(ast_node.func, ast.Name) and ast_node.func.id == self.instrumented_name:
            ast_node.func.id = BranchTransformer.to_instrumented_name(ast_node.func.id)
        return ast_node

    def visit_Compare(self, ast_node):
        """Replace a simple comparison by a call to ``evaluate_condition``."""
        # A chained comparison (a < b < c) cannot be expressed by one
        # (lhs, rhs) pair; leave it alone rather than dropping operands.
        if len(ast_node.ops) != 1:
            return ast_node
        operator = ast_node.ops[0]
        if isinstance(operator, self._SKIPPED_OPS):
            return ast_node
        self.branch_num += 1
        return ast.Call(
            func=ast.Name("evaluate_condition", ast.Load()),
            args=[
                ast.Constant(self.branch_num),
                ast.Constant(type(operator).__name__),
                ast_node.left,
                ast_node.comparators[0],
            ],
            keywords=[],
        )


def update_maps(condition_num, d_true, d_false):
    """Record the smallest true/false distance seen so far for a branch."""
    distances_true[condition_num] = min(distances_true.get(condition_num, d_true), d_true)
    distances_false[condition_num] = min(distances_false.get(condition_num, d_false), d_false)


def evaluate_condition(num, op, lhs, rhs):  # type: ignore
    """Evaluate an instrumented comparison and record its branch distances.

    :param num: branch number assigned by ``BranchTransformer``
    :param op: class name of the comparison operator (e.g. ``"Eq"``, ``"In"``)
    :param lhs: left operand
    :param rhs: right operand (for ``"In"``: an iterable of single characters)
    :return: ``True`` when the true-distance is zero
    """
    if op == "In":
        if isinstance(lhs, str):
            lhs = ord(lhs)
        # Distance to the closest member; sys.maxsize when the container is empty.
        distance_true = min((abs(lhs - ord(elem)) for elem in rhs), default=sys.maxsize)
        distance_false = 1 if distance_true == 0 else 0
    else:
        distance_true, distance_false = compute_distances(op, lhs, rhs)
    update_maps(num, distance_true, distance_false)
    # distance == 0 equivalent to actual test by construction
    return distance_true == 0


def normalize(x):
    """Map a non-negative distance into the interval [0, 1)."""
    return x / (1.0 + x)


def _add_branch_distances(fitness, distances, archive, candidate):
    """Archive newly covered branches and add the distances of uncovered ones."""
    for branch in branches:
        if branch not in distances:
            continue
        if distances[branch] == 0 and branch not in archive:
            archive[branch] = candidate
        if branch not in archive:
            fitness += normalize(distances[branch])
    return fitness


def get_fitness_cgi(individual):
    """Return the fitness of *individual* as a one-element tuple.

    The fitness is the sum of normalised distances of all branches in
    ``branches`` that were recorded and are not archived yet.
    """
    global distances_true, distances_false
    x = individual[0]

    # Reset any distance values from previous executions
    distances_true = {}
    distances_false = {}

    # TODO: fix this
    # Run the function under test
    # try:
    #     cgi_decode_instrumented(x)
    # except BaseException:
    #     pass

    # Sum up branch distances
    fitness = 0.0
    fitness = _add_branch_distances(fitness, distances_true, archive_true_branches, x)
    fitness = _add_branch_distances(fitness, distances_false, archive_false_branches, x)
    return fitness,


def random_string():
    """Return a random printable-ASCII string of length 0..MAX_STRING_LENGTH."""
    length = random.randint(0, MAX_STRING_LENGTH)
    return "".join(chr(random.randrange(32, 127)) for _ in range(length))


def crossover(individual1, individual2):
    """Single-point crossover on the string chromosomes (in place)."""
    parent1 = individual1[0]
    parent2 = individual2[0]
    if len(parent1) > 1 and len(parent2) > 1:
        pos = random.randint(1, len(parent1))
        individual1[0] = parent1[:pos] + parent2[pos:]
        individual2[0] = parent2[:pos] + parent1[pos:]
    return individual1, individual2


def mutate(individual):
    """Replace each character with probability 1/len by a random printable one."""
    mutated = individual[0][:]
    if len(mutated) > 0:
        prob = 1.0 / len(mutated)
        for pos in range(len(mutated)):
            if random.random() < prob:
                new_c = chr(random.randrange(32, 127))
                mutated = mutated[:pos] + new_c + mutated[pos + 1:]
    individual[0] = mutated
    return individual,


def generate():
    """Run the genetic algorithm REPS times and report the covered branches.

    :return: list with the number of archived (true + false) branches per run
    """
    global archive_true_branches, archive_false_branches
    # A single objective with weight -1.0, i.e. the fitness is minimised.
    creator.create("Fitness", base.Fitness, weights=(-1.0,))
    creator.create("Individual", list, fitness=creator.Fitness)

    toolbox = base.Toolbox()
    toolbox.register("attr_str", random_string)
    toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attr_str, n=1)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)
    toolbox.register("evaluate", get_fitness_cgi)
    toolbox.register("mate", crossover)
    toolbox.register("mutate", mutate)
    toolbox.register("select", tools.selTournament, tournsize=TOURNSIZE)

    coverage = []
    for _ in range(REPS):
        # Fresh archives for every repetition.
        archive_true_branches = {}
        archive_false_branches = {}
        population = toolbox.population(n=NPOP)
        algorithms.eaSimple(population, toolbox, CXPROB, MUPROB, NGEN, verbose=False)
        cov = len(archive_true_branches) + len(archive_false_branches)
        print(cov, archive_true_branches, archive_false_branches)
        coverage.append(cov)
    return coverage


# Textual annotation of a parameter; ``None`` is stored for unannotated ones.
ArgType = str
Arg = tuple[str, ArgType]
Params = frozendict[str, Any]
SignatureDict = dict[str, list[Arg]]

# Instrumented function name -> list of (parameter name, annotation).
functions: SignatureDict = {}
# Instrumented function name -> dotted module path relative to ROOT_DIR.
module_of: dict[str, str] = {}


def instrument(source_path: str, target_path: str, save_instrumented=True):
    """Instrument one source file and load its functions into this module.

    :param source_path: Python file to instrument
    :param target_path: where the instrumented source is written
    :param save_instrumented: write the instrumented source to *target_path*
    :raises ValueError: if a top-level function name has already been loaded
    """
    with open(source_path, "r", encoding="utf-8") as f:
        source = f.read()

    node = ast.parse(source)
    # print(ast.dump(node, indent=2))
    BranchTransformer().visit(node)
    node = ast.fix_missing_locations(node)  # Make sure the line numbers are ok before printing

    # Figure out the top level function definitions
    assert isinstance(node, ast.Module)
    top_level_f_ast: list[ast.FunctionDef] = [f for f in node.body if isinstance(f, ast.FunctionDef)]

    # Reject clashes before anything is written or executed, so that a failing
    # call does not leave this module half-updated.
    seen: set[str] = set()
    for f_def in top_level_f_ast:
        if f_def.name in functions or f_def.name in seen:
            raise ValueError(f"Function '{f_def.name}' already loaded from another file")
        seen.add(f_def.name)

    if save_instrumented:
        os.makedirs(os.path.dirname(target_path) or ".", exist_ok=True)
        with open(target_path, "w", encoding="utf-8") as f:
            print(astunparse.unparse(node), file=f)

    # NOTE: this executes the benchmark file inside this module's namespace;
    # only run it on trusted sources.
    current_module = sys.modules[__name__]
    code = compile(node, filename="", mode="exec")
    exec(code, current_module.__dict__)

    relative_path = os.path.normpath(os.path.relpath(source_path, ROOT_DIR))
    module_name = os.path.splitext(relative_path)[0].replace(os.sep, ".")

    for f_def in top_level_f_ast:
        arg_types: list[Arg] = []
        for arg in f_def.args.args:
            # fetch annotation type if found else fetch none; ast.unparse also
            # copes with annotations that are not plain names (e.g. list[int])
            arg_type = None if arg.annotation is None else ast.unparse(arg.annotation)
            arg_types.append((arg.arg, arg_type))
        functions[f_def.name] = arg_types
        module_of[f_def.name] = module_name


def invoke(f_name: str, f_args: Params) -> Any:
    """Call the loaded function *f_name* with keyword arguments *f_args*.

    :raises ValueError: if the function is not loaded or an argument is missing
    """
    if f_name not in functions:
        raise ValueError(f"Function '{f_name}' not loaded")

    for arg_name, _arg_type in functions[f_name]:
        if arg_name not in f_args:
            raise ValueError(f"Required argument '{arg_name}' not provided")

    current_module = sys.modules[__name__]
    return getattr(current_module, f_name)(**f_args)


def find_py_files(search_dir: str):
    """Yield the path of every ``.py`` file below *search_dir*."""
    for (cwd, dirs, files) in os.walk(search_dir):
        for file in files:
            if file.endswith(".py"):
                yield os.path.join(cwd, file)


def load_benchmark(save_instrumented=True):
    """Instrument every Python file found in IN_DIR."""
    for file in tqdm.tqdm(find_py_files(IN_DIR), desc="Instrumenting"):
        instrument(file, os.path.join(OUT_DIR, os.path.basename(file)), save_instrumented=save_instrumented)


def call_statement(f_name: str, f_args: Params) -> str:
    """Render ``f_name(k1=v1, k2=v2, ...)`` as Python source text."""
    arg_list: list[str] = []
    for k, v in f_args.items():
        if isinstance(v, str):
            # repr() quotes the string and escapes embedded quotes/backslashes
            arg_list.append(f"{k}={v!r}")
        else:
            arg_list.append(f"{k}={v}")
    return f"{f_name}({', '.join(arg_list)})"


if __name__ == '__main__':
    load_benchmark(save_instrumented=True)