This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
kse-02/instrument.py
2023-12-18 15:13:31 +01:00

191 lines
5.9 KiB
Python

import ast
import os.path
import sys
from collections import defaultdict
from typing import Any, Optional

import astunparse
import tqdm
from frozendict import frozendict

from operators import evaluate_condition
# Directory that contains this module; the other paths are derived from it.
ROOT_DIR: str = os.path.split(__file__)[0]

# Directory searched for the benchmark sources to instrument.
IN_DIR: str = os.path.join(ROOT_DIR, "benchmark")

# Directory where the instrumented sources are written.
OUT_DIR: str = os.path.join(ROOT_DIR, "instrumented")

# Suffix appended to the name of every instrumented function.
SUFFIX: str = "_instrumented"
class BranchTransformer(ast.NodeTransformer):
    """
    AST transformer that instruments comparisons so that their outcome can be observed.

    Every function definition is renamed by appending ``SUFFIX`` (self-recursive calls are
    renamed accordingly) and every eligible comparison ``a <op> b`` is replaced by a call
    ``<evaluate_condition>(n, "<OpClassName>", a, b)``, where ``n`` is a 1-based branch
    number that keeps increasing for the whole lifetime of the transformer.

    The following comparisons are left untouched:

    - comparisons located inside ``assert`` or ``return`` statements;
    - identity and membership tests (``is``, ``is not``, ``in``, ``not in``);
    - chained comparisons (``a < b < c``), since rewriting them as a single two-operand
      call would silently drop all operators after the first one.
    """

    # Operator node types whose comparisons are never instrumented
    _SKIPPED_OPS = (ast.Is, ast.IsNot, ast.In, ast.NotIn)

    # instrumented function name -> (first branch number, last branch number + 1)
    branches_range: dict[str, tuple[int, int]]
    # number of the last branch instrumented so far (0 when none)
    branch_num: int
    # original name of the function currently being visited, if any
    instrumented_name: Optional[str]
    # True while the children of an `assert` statement are being visited
    in_assert: bool
    # True while the children of a `return` statement are being visited
    in_return: bool

    def __init__(self):
        super().__init__()
        self.branch_num = 0
        self.instrumented_name = None
        self.in_assert = False
        self.in_return = False
        self.branches_range = {}

    @staticmethod
    def to_instrumented_name(name: str):
        """
        Returns the instrumented counterpart of the function name `name`
        """
        return name + SUFFIX

    @staticmethod
    def to_original_name(name: str):
        """
        Returns the original function name given the instrumented name `name`, which must
        end with ``SUFFIX``
        """
        assert name.endswith(SUFFIX)
        return name[:len(name) - len(SUFFIX)]

    def visit_Assert(self, ast_node):
        """
        Visits an `assert` statement, flagging its comparisons as not to be instrumented
        """
        self.in_assert = True
        self.generic_visit(ast_node)
        self.in_assert = False
        return ast_node

    def visit_Return(self, ast_node):
        """
        Visits a `return` statement, flagging its comparisons as not to be instrumented
        """
        self.in_return = True
        self.generic_visit(ast_node)
        self.in_return = False
        return ast_node

    def visit_FunctionDef(self, ast_node):
        """
        Renames the function to its instrumented name, instruments its body and records in
        `branches_range` the half-open interval of branch numbers assigned within it
        """
        # Remember the enclosing function (if any) so that a nested definition does not
        # erase it once visited
        enclosing_name = self.instrumented_name
        self.instrumented_name = ast_node.name
        b_start = self.branch_num
        ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name)
        inner_node = self.generic_visit(ast_node)
        self.branches_range[ast_node.name] = (b_start + 1, self.branch_num + 1)
        self.instrumented_name = enclosing_name
        return inner_node

    def visit_Call(self, ast_node):
        """
        Redirects self-recursive calls of the function being visited to its instrumented name
        """
        # Recurse first: arguments may contain further recursive calls or comparisons
        self.generic_visit(ast_node)
        if isinstance(ast_node.func, ast.Name) and ast_node.func.id == self.instrumented_name:
            ast_node.func.id = BranchTransformer.to_instrumented_name(ast_node.func.id)
        return ast_node

    def visit_Compare(self, ast_node):
        """
        Replaces an eligible comparison with a call to the condition evaluation function,
        assigning it the next branch number; other comparisons are returned unchanged
        """
        # Operands are always visited, so that calls and comparisons nested inside them
        # are processed whether or not this comparison is instrumented
        self.generic_visit(ast_node)

        if (self.in_assert or self.in_return
                or len(ast_node.ops) != 1
                or isinstance(ast_node.ops[0], BranchTransformer._SKIPPED_OPS)):
            return ast_node

        self.branch_num += 1
        call = ast.Call(func=ast.Name(id=evaluate_condition.__name__, ctx=ast.Load()),
                        args=[ast.Constant(value=self.branch_num),
                              ast.Constant(value=ast_node.ops[0].__class__.__name__),
                              ast_node.left,
                              ast_node.comparators[0]],
                        keywords=[])
        # Keep the position of the comparison being replaced
        return ast.copy_location(call, ast_node)
# Type annotation of a parameter as it appears in the source, or None when unannotated
ArgType = Optional[str]
# (parameter name, parameter type annotation)
Arg = tuple[str, ArgType]
# Immutable mapping from parameter names to the values to invoke a function with
Params = frozendict[str, Any]
# Instrumented function name -> parameters of the function
SignatureDict = dict[str, list[Arg]]

# Instrumented function name -> (first branch number, last branch number + 1)
n_of_branches: dict[str, tuple[int, int]] = {}
# Signatures of all the top-level instrumented functions loaded so far
functions: SignatureDict = {}
# Instrumented function name -> components of the path (relative to ROOT_DIR, extension
# excluded) of the source file defining the function
module_of: dict[str, list[str]] = {}
def instrument(source_path: str, target_path: str, save_instrumented=True):
    """
    Instruments the Python source file `source_path` and loads the result in this module.

    The instrumented code is executed within the namespace of this module, so its
    functions become module attributes. Branch ranges, signatures and source locations
    of the top-level functions are recorded in `n_of_branches`, `functions` and
    `module_of` respectively.

    :param source_path: path of the source file to instrument
    :param target_path: path the instrumented source is written to
    :param save_instrumented: when false, nothing is written to `target_path`
    """
    with open(source_path, "r") as source_file:
        tree = ast.parse(source_file.read())

    transformer = BranchTransformer()
    transformer.visit(tree)
    n_of_branches.update(transformer.branches_range)

    # Nodes created by the transformer may lack the position info compile() requires
    tree = ast.fix_missing_locations(tree)

    if save_instrumented:
        # Create the output directory on demand instead of failing when it is missing
        target_dir = os.path.dirname(target_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        with open(target_path, "w") as target_file:
            print(astunparse.unparse(tree), file=target_file)

    # NOTE(review): this executes the content of `source_path` inside this module;
    # only instrument trusted files
    current_module = sys.modules[__name__]
    code = compile(tree, filename="<ast>", mode="exec")
    exec(code, current_module.__dict__)

    assert isinstance(tree, ast.Module)

    # Path of the source relative to ROOT_DIR, without extension, split on the platform
    # separator (normpath() emits os.sep, which is not "/" on every platform)
    relative_path = os.path.normpath(os.path.relpath(source_path, ROOT_DIR))
    module_path = os.path.splitext(relative_path)[0].split(os.sep)

    # Register the top-level function definitions only
    for definition in tree.body:
        if not isinstance(definition, ast.FunctionDef):
            continue

        arg_types: list[Arg] = []
        for arg in definition.args.args:
            # Annotation as source text (the bare name for simple annotations such as
            # `int`), or None when the parameter is not annotated
            arg_type = None if arg.annotation is None else ast.unparse(arg.annotation)
            arg_types.append((arg.arg, arg_type))

        functions[definition.name] = arg_types
        module_of[definition.name] = list(module_path)
def invoke(f_name: str, f_args: Params) -> Any:
    """
    Calls the loaded function named `f_name` passing `f_args` as keyword arguments.

    :param f_name: name of the function as registered in `functions`
    :param f_args: values of the arguments, keyed by parameter name
    :return: whatever the invoked function returns
    :raises ValueError: if the function is not loaded, or if `f_args` lacks a value for
        one of its registered parameters
    """
    if f_name not in functions:
        raise ValueError(f"Function '{f_name}' not loaded")

    for arg_name, _ in functions[f_name]:
        if arg_name not in f_args:
            raise ValueError(f"Required argument '{arg_name}' not provided")

    # Loaded functions live in the namespace of this module
    current_module = sys.modules[__name__]
    return getattr(current_module, f_name)(**f_args)
def find_py_files(search_dir: str):
    """
    Lazily yields the path of every file ending in ".py" found in `search_dir` or in any
    of its subdirectories
    """
    for folder, _subdirs, names in os.walk(search_dir):
        yield from (os.path.join(folder, name) for name in names if name.endswith(".py"))
def load_benchmark(save_instrumented=True, files: list[str] = ()):
    """
    Instruments and loads the Python files found in `IN_DIR`.

    :param save_instrumented: forwarded to `instrument`
    :param files: when not empty, only the files whose name (directory and extension
        are ignored) matches one of these are processed; otherwise all files are
    """
    wanted = {os.path.splitext(os.path.basename(entry))[0] + ".py" for entry in files}

    for path in tqdm.tqdm(find_py_files(IN_DIR), desc="Instrumenting"):
        base_name = os.path.basename(path)
        # An empty selection means "everything"
        if wanted and base_name not in wanted:
            continue
        instrument(path, os.path.join(OUT_DIR, base_name), save_instrumented=save_instrumented)
def call_statement(f_name: str, f_args: Params) -> str:
    """
    Builds the source text of a call to `f_name` where each entry of `f_args` is passed
    as a keyword argument, its value being rendered with `repr`
    """
    rendered_args = ", ".join(f"{name}={value!r}" for name, value in f_args.items())
    return f"{f_name}({rendered_args})"
def get_benchmark() -> dict[str, list[str]]:
    """
    Returns a dictionary associating the name (without extension) of each source code file
    loaded with the list of (non-instrumented) names of the functions defined within it
    """
    grouped: dict[str, list[str]] = {}
    for instrumented_name in functions:
        # The last path component is the file name
        file_name = module_of[instrumented_name][-1]
        original_name = BranchTransformer.to_original_name(instrumented_name)
        grouped.setdefault(file_name, []).append(original_name)
    return grouped
# When run as a script, instrument the whole benchmark and save the instrumented sources
if __name__ == "__main__":
    load_benchmark(save_instrumented=True)