instrument script supports function execution

This commit is contained in:
Claudio Maggioni 2023-11-15 13:32:08 +01:00
parent af6d21dbb1
commit a68f405674
12 changed files with 80 additions and 17 deletions

View file

@ -1,6 +1,6 @@
def anagram_check(s1: str, s2: str) -> bool: def anagram_check_instrumented(s1: str, s2: str) -> bool:
if (evaluate_condition(1, 'Eq', len(s1), 1) and evaluate_condition(2, 'Eq', len(s2), 1)): if (evaluate_condition(1, 'Eq', len(s1), 1) and evaluate_condition(2, 'Eq', len(s2), 1)):
return (s1 == s2) return (s1 == s2)
if evaluate_condition(3, 'NotEq', len(s1), len(s2)): if evaluate_condition(3, 'NotEq', len(s1), len(s2)):

View file

@ -1,6 +1,6 @@
def encrypt(strng: str, key: int) -> str: def encrypt_instrumented(strng: str, key: int) -> str:
assert (0 < key <= 94) assert (0 < key <= 94)
encrypted = '' encrypted = ''
for x in strng: for x in strng:
@ -10,7 +10,7 @@ def encrypt(strng: str, key: int) -> str:
encrypted = (encrypted + chr(indx)) encrypted = (encrypted + chr(indx))
return encrypted return encrypted
def decrypt(strng: str, key: int) -> str: def decrypt_instrumented(strng: str, key: int) -> str:
assert (0 < key <= 94) assert (0 < key <= 94)
decrypted = '' decrypted = ''
for x in strng: for x in strng:

View file

@ -1,6 +1,6 @@
def check_armstrong(n: int) -> bool: def check_armstrong_instrumented(n: int) -> bool:
assert (n >= 0) assert (n >= 0)
if (evaluate_condition(1, 'Eq', n, 0) or evaluate_condition(2, 'Eq', n, 1)): if (evaluate_condition(1, 'Eq', n, 0) or evaluate_condition(2, 'Eq', n, 1)):
return True return True

View file

@ -1,7 +1,7 @@
'\nThe function takes two integers as input and return the number of common divisors of\nthat pair\n' '\nThe function takes two integers as input and return the number of common divisors of\nthat pair\n'
def cd_count(a: int, b: int) -> int: def cd_count_instrumented(a: int, b: int) -> int:
if (evaluate_condition(1, 'Eq', a, 0) or evaluate_condition(2, 'Eq', b, 0)): if (evaluate_condition(1, 'Eq', a, 0) or evaluate_condition(2, 'Eq', b, 0)):
return 2 return 2
a = (((- 1) * a) if evaluate_condition(3, 'Lt', a, 0) else a) a = (((- 1) * a) if evaluate_condition(3, 'Lt', a, 0) else a)

View file

@ -1,6 +1,6 @@
def exponentiation(baseNumber: int, power: int) -> float: def exponentiation_instrumented(baseNumber: int, power: int) -> float:
assert (not ((baseNumber == 0) or (power <= 0))) assert (not ((baseNumber == 0) or (power <= 0)))
answer = None answer = None
if evaluate_condition(1, 'Gt', power, 1): if evaluate_condition(1, 'Gt', power, 1):

View file

@ -1,6 +1,6 @@
def gcd(a: int, b: int) -> int: def gcd_instrumented(a: int, b: int) -> int:
assert ((a > 0) and (b > 0)) assert ((a > 0) and (b > 0))
if (evaluate_condition(1, 'Eq', a, 1) or evaluate_condition(2, 'Eq', b, 1)): if (evaluate_condition(1, 'Eq', a, 1) or evaluate_condition(2, 'Eq', b, 1)):
return 1 return 1

View file

@ -1,6 +1,6 @@
def longest_sorted_substr(s: str) -> str: def longest_sorted_substr_instrumented(s: str) -> str:
count = 0 count = 0
max_count = 0 max_count = 0
end_position = 0 end_position = 0

View file

@ -1,6 +1,6 @@
def rabin_karp_search(pat: str, txt: str) -> list: def rabin_karp_search_instrumented(pat: str, txt: str) -> list:
assert (len(pat) <= len(txt)) assert (len(pat) <= len(txt))
d = 2560 d = 2560
q = 101 q = 101

View file

@ -1,6 +1,6 @@
def railencrypt(st: str, k: int) -> str: def railencrypt_instrumented(st: str, k: int) -> str:
assert (k > 1) assert (k > 1)
c = 0 c = 0
x = 0 x = 0
@ -25,7 +25,7 @@ def railencrypt(st: str, k: int) -> str:
result.append(chr(m[i][j])) result.append(chr(m[i][j]))
return ''.join(result) return ''.join(result)
def raildecrypt(st: str, k: int) -> str: def raildecrypt_instrumented(st: str, k: int) -> str:
assert (k > 1) assert (k > 1)
(c, x) = (0, 0) (c, x) = (0, 0)
m = [([0] * len(st)) for i in range(k)] m = [([0] * len(st)) for i in range(k)]

View file

@ -1,6 +1,6 @@
def zeller(d: int, m: int, y: int) -> str: def zeller_instrumented(d: int, m: int, y: int) -> str:
assert (abs(d) >= 1) assert (abs(d) >= 1)
assert (abs(m) >= 1) assert (abs(m) >= 1)
assert ((0 <= abs(y) <= 99) or (1000 <= abs(y) <= 3000)) assert ((0 <= abs(y) <= 99) or (1000 <= abs(y) <= 3000))

View file

@ -87,7 +87,7 @@ def int_str_check(a: any, b: any) -> bool:
return True return True
if type(a) != str or type(b) != str: if type(a) != str or type(b) != str:
return False return False
return len(a) == 1 or len(b) == 1 return len(a) == 1 and len(b) == 1
def int_str_convert(x: int | str) -> int: def int_str_convert(x: int | str) -> int:
@ -135,7 +135,7 @@ def compute_distances(name: str, lhs: any, rhs: any) -> tuple[int, int]:
if name not in str_by_name: if name not in str_by_name:
raise ValueError(f"'{name}' is not a valid CmpOp name for 'str' operators") raise ValueError(f"'{name}' is not a valid CmpOp name for 'str' operators")
op = int_str_by_name[name] op = str_by_name[name]
return op.true_dist(lhs, rhs), op.false_dist(lhs, rhs) return op.true_dist(lhs, rhs), op.false_dist(lhs, rhs)
raise ValueError(f"'{lhs}' and '{rhs}' are not suitable for both 'int_str' and 'str' operators") raise ValueError(f"'{lhs}' and '{rhs}' are not suitable for both 'int_str' and 'str' operators")

View file

@ -56,6 +56,7 @@ class BranchTransformer(ast.NodeTransformer):
def visit_FunctionDef(self, ast_node): def visit_FunctionDef(self, ast_node):
self.instrumented_name = ast_node.name self.instrumented_name = ast_node.name
ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name)
inner_node = self.generic_visit(ast_node) inner_node = self.generic_visit(ast_node)
self.instrumented_name = None self.instrumented_name = None
return inner_node return inner_node
@ -212,12 +213,22 @@ def generate():
coverage.append(cov) coverage.append(cov)
ArgType = str
Arg = tuple[str, ArgType]
SignatureDict = dict[str, list[Arg]]
functions: SignatureDict = {}
def instrument(source_path: str, target_path: str): def instrument(source_path: str, target_path: str):
global functions
with open(source_path, "r") as f: with open(source_path, "r") as f:
source = f.read() source = f.read()
node = ast.parse(source) node = ast.parse(source)
print(ast.dump(node, indent=2)) # print(ast.dump(node, indent=2))
BranchTransformer().visit(node) BranchTransformer().visit(node)
node = ast.fix_missing_locations(node) # Make sure the line numbers are ok before printing node = ast.fix_missing_locations(node) # Make sure the line numbers are ok before printing
@ -226,7 +237,46 @@ def instrument(source_path: str, target_path: str):
current_module = sys.modules[__name__] current_module = sys.modules[__name__]
code = compile(node, filename="<ast>", mode="exec") code = compile(node, filename="<ast>", mode="exec")
exec(code, current_module.__dict__) # try: cgi_decode_instrumented("a%20%32"), print distances_true exec(code, current_module.__dict__)
# Figure out the top level function definitions
assert isinstance(node, ast.Module)
top_level_f_ast: list[ast.FunctionDef] = [f for f in node.body if isinstance(f, ast.FunctionDef)]
for f in top_level_f_ast:
if f.name in functions:
raise ValueError(f"Function '{f.name}' already loaded from another file")
arg_types: list[Arg] = []
for arg in f.args.args:
# fetch annotation type if found else fetch none
# noinspection PyUnresolvedReferences
arg_type = None if arg.annotation is None else arg.annotation.id
arg_types.append((arg.arg, arg_type))
functions[f.name] = arg_types
def invoke_signature(fun_name: str, arg_values: dict[str]) -> any:
global functions
current_module = sys.modules[__name__]
if fun_name not in functions:
raise ValueError(f"Function '{fun_name}' not loaded")
args = functions[fun_name]
for arg_name, arg_type in args:
if arg_name not in arg_values:
raise ValueError(f"Required argument '{arg_name}' not provided")
arg_str = ",".join([f"{k}={v}" for k, v in arg_values.items()])
print(f"Calling {fun_name}({arg_str})")
return getattr(current_module, fun_name)(**arg_values)
def find_py_files(search_dir: str): def find_py_files(search_dir: str):
@ -237,9 +287,22 @@ def find_py_files(search_dir: str):
def main(): def main():
global functions
for file in find_py_files(IN_DIR): for file in find_py_files(IN_DIR):
instrument(file, os.path.join(OUT_DIR, os.path.basename(file))) instrument(file, os.path.join(OUT_DIR, os.path.basename(file)))
# generate()
for function, arg_signatures in functions.items():
args = {}
for arg_name, arg_type in arg_signatures:
if arg_type == 'int':
args[arg_name] = 42
elif arg_type == 'str':
args[arg_name] = 'hello world'
else:
args[arg_name] = None
invoke_signature(function, args)
if __name__ == '__main__': if __name__ == '__main__':