diff --git a/instrumented/anagram_check.py b/instrumented/anagram_check.py index ed5fd3f..616bd1f 100644 --- a/instrumented/anagram_check.py +++ b/instrumented/anagram_check.py @@ -1,6 +1,6 @@ -def anagram_check(s1: str, s2: str) -> bool: +def anagram_check_instrumented(s1: str, s2: str) -> bool: if (evaluate_condition(1, 'Eq', len(s1), 1) and evaluate_condition(2, 'Eq', len(s2), 1)): return (s1 == s2) if evaluate_condition(3, 'NotEq', len(s1), len(s2)): diff --git a/instrumented/caesar_cipher.py b/instrumented/caesar_cipher.py index e29eb11..ab7e9df 100644 --- a/instrumented/caesar_cipher.py +++ b/instrumented/caesar_cipher.py @@ -1,6 +1,6 @@ -def encrypt(strng: str, key: int) -> str: +def encrypt_instrumented(strng: str, key: int) -> str: assert (0 < key <= 94) encrypted = '' for x in strng: @@ -10,7 +10,7 @@ def encrypt(strng: str, key: int) -> str: encrypted = (encrypted + chr(indx)) return encrypted -def decrypt(strng: str, key: int) -> str: +def decrypt_instrumented(strng: str, key: int) -> str: assert (0 < key <= 94) decrypted = '' for x in strng: diff --git a/instrumented/check_armstrong.py b/instrumented/check_armstrong.py index 2dabeba..131de56 100644 --- a/instrumented/check_armstrong.py +++ b/instrumented/check_armstrong.py @@ -1,6 +1,6 @@ -def check_armstrong(n: int) -> bool: +def check_armstrong_instrumented(n: int) -> bool: assert (n >= 0) if (evaluate_condition(1, 'Eq', n, 0) or evaluate_condition(2, 'Eq', n, 1)): return True diff --git a/instrumented/common_divisor_count.py b/instrumented/common_divisor_count.py index 43213d6..c6f7db2 100644 --- a/instrumented/common_divisor_count.py +++ b/instrumented/common_divisor_count.py @@ -1,7 +1,7 @@ '\nThe function takes two integers as input and return the number of common divisors of\nthat pair\n' -def cd_count(a: int, b: int) -> int: +def cd_count_instrumented(a: int, b: int) -> int: if (evaluate_condition(1, 'Eq', a, 0) or evaluate_condition(2, 'Eq', b, 0)): return 2 a = (((- 1) * a) if evaluate_condition(3, 'Lt', a, 0) else a) diff --git a/instrumented/exponentiation.py b/instrumented/exponentiation.py index 72d310a..0a1e421 100644 --- a/instrumented/exponentiation.py +++ b/instrumented/exponentiation.py @@ -1,6 +1,6 @@ -def exponentiation(baseNumber: int, power: int) -> float: +def exponentiation_instrumented(baseNumber: int, power: int) -> float: assert (not ((baseNumber == 0) or (power <= 0))) answer = None if evaluate_condition(1, 'Gt', power, 1): diff --git a/instrumented/gcd.py b/instrumented/gcd.py index 955435d..7b39772 100644 --- a/instrumented/gcd.py +++ b/instrumented/gcd.py @@ -1,6 +1,6 @@ -def gcd(a: int, b: int) -> int: +def gcd_instrumented(a: int, b: int) -> int: assert ((a > 0) and (b > 0)) if (evaluate_condition(1, 'Eq', a, 1) or evaluate_condition(2, 'Eq', b, 1)): return 1 diff --git a/instrumented/longest_substring.py b/instrumented/longest_substring.py index 66587b9..d7c433b 100644 --- a/instrumented/longest_substring.py +++ b/instrumented/longest_substring.py @@ -1,6 +1,6 @@ -def longest_sorted_substr(s: str) -> str: +def longest_sorted_substr_instrumented(s: str) -> str: count = 0 max_count = 0 end_position = 0 diff --git a/instrumented/rabin_karp.py b/instrumented/rabin_karp.py index 0903173..b4a7600 100644 --- a/instrumented/rabin_karp.py +++ b/instrumented/rabin_karp.py @@ -1,6 +1,6 @@ -def rabin_karp_search(pat: str, txt: str) -> list: +def rabin_karp_search_instrumented(pat: str, txt: str) -> list: assert (len(pat) <= len(txt)) d = 2560 q = 101 diff --git a/instrumented/railfence_cipher.py b/instrumented/railfence_cipher.py index ceee782..4723811 100644 --- a/instrumented/railfence_cipher.py +++ b/instrumented/railfence_cipher.py @@ -1,6 +1,6 @@ -def railencrypt(st: str, k: int) -> str: +def railencrypt_instrumented(st: str, k: int) -> str: assert (k > 1) c = 0 x = 0 @@ -25,7 +25,7 @@ def railencrypt(st: str, k: int) -> str: result.append(chr(m[i][j])) return ''.join(result) -def raildecrypt(st: str, k: int) -> str: +def raildecrypt_instrumented(st: str, k: int) -> str: assert (k > 1) (c, x) = (0, 0) m = [([0] * len(st)) for i in range(k)] diff --git a/instrumented/zellers_birthday.py b/instrumented/zellers_birthday.py index 5b71b41..ce39468 100644 --- a/instrumented/zellers_birthday.py +++ b/instrumented/zellers_birthday.py @@ -1,6 +1,6 @@ -def zeller(d: int, m: int, y: int) -> str: +def zeller_instrumented(d: int, m: int, y: int) -> str: assert (abs(d) >= 1) assert (abs(m) >= 1) assert ((0 <= abs(y) <= 99) or (1000 <= abs(y) <= 3000)) diff --git a/instrumentor.py b/instrumentor.py index d8f2e60..3f865f4 100644 --- a/instrumentor.py +++ b/instrumentor.py @@ -87,7 +87,7 @@ def int_str_check(a: any, b: any) -> bool: return True if type(a) != str or type(b) != str: return False - return len(a) == 1 or len(b) == 1 + return len(a) == 1 and len(b) == 1 def int_str_convert(x: int | str) -> int: @@ -135,7 +135,7 @@ def compute_distances(name: str, lhs: any, rhs: any) -> tuple[int, int]: if name not in str_by_name: raise ValueError(f"'{name}' is not a valid CmpOp name for 'str' operators") - op = int_str_by_name[name] + op = str_by_name[name] return op.true_dist(lhs, rhs), op.false_dist(lhs, rhs) raise ValueError(f"'{lhs}' and '{rhs}' are not suitable for both 'int_str' and 'str' operators") diff --git a/sb_cgi_decode.py b/sb_cgi_decode.py index 1111944..a60274c 100644 --- a/sb_cgi_decode.py +++ b/sb_cgi_decode.py @@ -56,6 +56,7 @@ class BranchTransformer(ast.NodeTransformer): def visit_FunctionDef(self, ast_node): self.instrumented_name = ast_node.name + ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name) inner_node = self.generic_visit(ast_node) self.instrumented_name = None return inner_node @@ -212,12 +213,22 @@ def generate(): coverage.append(cov) +ArgType = str +Arg = tuple[str, ArgType] +SignatureDict = dict[str, list[Arg]] + + +functions: SignatureDict = {} + + def instrument(source_path: str, target_path: str): + global functions + with open(source_path, "r") as f: source = f.read() node = ast.parse(source) - print(ast.dump(node, indent=2)) + # print(ast.dump(node, indent=2)) BranchTransformer().visit(node) node = ast.fix_missing_locations(node) # Make sure the line numbers are ok before printing @@ -226,7 +237,46 @@ def instrument(source_path: str, target_path: str): current_module = sys.modules[__name__] code = compile(node, filename="", mode="exec") - exec(code, current_module.__dict__) # try: cgi_decode_instrumented("a%20%32"), print distances_true + exec(code, current_module.__dict__) + + # Figure out the top level function definitions + assert isinstance(node, ast.Module) + top_level_f_ast: list[ast.FunctionDef] = [f for f in node.body if isinstance(f, ast.FunctionDef)] + + for f in top_level_f_ast: + if f.name in functions: + raise ValueError(f"Function '{f.name}' already loaded from another file") + + arg_types: list[Arg] = [] + + for arg in f.args.args: + # fetch annotation type if found else fetch none + # noinspection PyUnresolvedReferences + arg_type = None if arg.annotation is None else arg.annotation.id + + arg_types.append((arg.arg, arg_type)) + + functions[f.name] = arg_types + + +def invoke_signature(fun_name: str, arg_values: dict[str]) -> any: + global functions + + current_module = sys.modules[__name__] + + if fun_name not in functions: + raise ValueError(f"Function '{fun_name}' not loaded") + + args = functions[fun_name] + for arg_name, arg_type in args: + if arg_name not in arg_values: + raise ValueError(f"Required argument '{arg_name}' not provided") + + + arg_str = ",".join([f"{k}={v}" for k, v in arg_values.items()]) + print(f"Calling {fun_name}({arg_str})") + + return getattr(current_module, fun_name)(**arg_values) def find_py_files(search_dir: str): @@ -237,9 +287,22 @@ def find_py_files(search_dir: str): def main(): + global functions + for file in find_py_files(IN_DIR): instrument(file, os.path.join(OUT_DIR, os.path.basename(file))) - # generate() + + for function, arg_signatures in functions.items(): + args = {} + for arg_name, arg_type in arg_signatures: + if arg_type == 'int': + args[arg_name] = 42 + elif arg_type == 'str': + args[arg_name] = 'hello world' + else: + args[arg_name] = None + + invoke_signature(function, args) if __name__ == '__main__':