diff --git a/sb_cgi_decode.py b/instrument.py similarity index 82% rename from sb_cgi_decode.py rename to instrument.py index a60274c..ac8b602 100644 --- a/sb_cgi_decode.py +++ b/instrument.py @@ -6,7 +6,7 @@ import astunparse import sys import random from deap import creator, base, tools, algorithms -from instrumentor import compute_distances +from operators import compute_distances # hyperparameters NPOP = 300 @@ -22,6 +22,7 @@ MAX_STRING_LENGTH = 10 IN_DIR: str = os.path.join(os.path.dirname(__file__), 'benchmark') OUT_DIR: str = os.path.join(os.path.dirname(__file__), 'instrumented') +SUFFIX: str = "_instrumented" distances_true: dict[int, int] = {} distances_false: dict[int, int] = {} @@ -30,10 +31,6 @@ archive_true_branches: dict[int, str] = {} archive_false_branches: dict[int, str] = {} -def cgi_decode_instrumented(s: str) -> str: - return "" # make mypy happy - - class BranchTransformer(ast.NodeTransformer): branch_num: int instrumented_name: Optional[str] @@ -44,7 +41,12 @@ class BranchTransformer(ast.NodeTransformer): @staticmethod def to_instrumented_name(name: str): - return name + "_instrumented" + return name + SUFFIX + + @staticmethod + def to_original_name(name: str): + assert name.endswith(SUFFIX) + return name[:len(name) - len(SUFFIX)] def visit_Assert(self, ast_node): # Disable recursion in asserts, i.e. do not instrument assert conditions @@ -128,11 +130,12 @@ def get_fitness_cgi(individual): distances_true = {} distances_false = {} + # TODO: fix this # Run the function under test - try: - cgi_decode_instrumented(x) - except BaseException: - pass + # try: + # cgi_decode_instrumented(x) + # except BaseException: + # pass # Sum up branch distances fitness = 0.0 @@ -215,13 +218,13 @@ def generate(): ArgType = str Arg = tuple[str, ArgType] +Params = dict[str, any] SignatureDict = dict[str, list[Arg]] - functions: SignatureDict = {} -def instrument(source_path: str, target_path: str): +def instrument(source_path: str, target_path: str, save_instrumented=True): global functions with open(source_path, "r") as f: @@ -232,8 +235,9 @@ def instrument(source_path: str, target_path: str): BranchTransformer().visit(node) node = ast.fix_missing_locations(node) # Make sure the line numbers are ok before printing - with open(target_path, "w") as f: - print(astunparse.unparse(node), file=f) + if save_instrumented: + with open(target_path, "w") as f: + print(astunparse.unparse(node), file=f) current_module = sys.modules[__name__] code = compile(node, filename="", mode="exec") @@ -259,24 +263,20 @@ def instrument(source_path: str, target_path: str): functions[f.name] = arg_types -def invoke_signature(fun_name: str, arg_values: dict[str]) -> any: +def invoke_signature(f_name: str, f_args: Params) -> any: global functions current_module = sys.modules[__name__] - if fun_name not in functions: - raise ValueError(f"Function '{fun_name}' not loaded") + if f_name not in functions: + raise ValueError(f"Function '{f_name}' not loaded") - args = functions[fun_name] - for arg_name, arg_type in args: - if arg_name not in arg_values: + f_args_signature = functions[f_name] + for arg_name, arg_type in f_args_signature: + if arg_name not in f_args: raise ValueError(f"Required argument '{arg_name}' not provided") - - arg_str = ",".join([f"{k}={v}" for k, v in arg_values.items()]) - print(f"Calling {fun_name}({arg_str})") - - return getattr(current_module, fun_name)(**arg_values) + return getattr(current_module, f_name)(**f_args) def find_py_files(search_dir: str): @@ -286,24 +286,41 @@ def find_py_files(search_dir: str): yield os.path.join(cwd, file) -def main(): +def load_benchmark(save_instrumented=True): + for file in find_py_files(IN_DIR): + instrument(file, os.path.join(OUT_DIR, os.path.basename(file)), save_instrumented=save_instrumented) + + +def run_all_example(): global functions - for file in find_py_files(IN_DIR): - instrument(file, os.path.join(OUT_DIR, os.path.basename(file))) - - for function, arg_signatures in functions.items(): + for f_name, f_args_signature in functions.items(): args = {} - for arg_name, arg_type in arg_signatures: + for arg_name, arg_type in f_args_signature: + # Generate some dummy values appropriate for each type + if arg_type == 'int': args[arg_name] = 42 elif arg_type == 'str': args[arg_name] = 'hello world' else: - args[arg_name] = None + raise ValueError(f"Arg type '{arg_type}' for '{arg_name}' not supported") - invoke_signature(function, args) + out = invoke_signature(f_name, args) + print(call_statement(f_name, args), "=", out) + + +def call_statement(f_name: str, f_args: Params, ) -> str: + arg_list: list[str] = [] + for k, v in f_args.items(): + if type(v) == str: + arg_list.append(f"{k}='{v}'") # quote strings + else: + arg_list.append(f"{k}={v}") + + return f"{f_name}({', '.join(arg_list)})" if __name__ == '__main__': - main() + load_benchmark(save_instrumented=True) + run_all_example() diff --git a/instrumentor.py b/operators.py similarity index 85% rename from instrumentor.py rename to operators.py index 3f865f4..fce8d76 100644 --- a/instrumentor.py +++ b/operators.py @@ -25,26 +25,6 @@ class CmpOp(Generic[T]): self.false_dist = false_dist -# @dataclass -# class InstrState: -# min_true_dist: Optional[int] -# min_false_dist: Optional[int] -# -# def __init__(self): -# self.min_true_dist = None -# self.min_false_dist = None -# -# def update(self, op: CmpOp[U], lhs: U, rhs: U): -# true_dist = op.true_dist(lhs, rhs) -# self.min_true_dist = true_dist if self.min_true_dist is None else min(true_dist, self.min_true_dist) -# -# false_dist = op.false_dist(lhs, rhs) -# self.min_false_dist = false_dist if self.min_false_dist is None else min(false_dist, self.min_false_dist) -# -# -# instrumentation_states: defaultdict[int, InstrState] = defaultdict(InstrState) - - # Operands for these must both be integers or strings of length 1 int_str_ops: list[CmpOp[int | str]] = [ CmpOp(operator='<',