This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
kse-02/instrument.py

192 lines
5.9 KiB
Python
Raw Normal View History

2023-12-09 16:56:04 +00:00
import ast
import os.path
import sys
from collections import defaultdict
from typing import Any, Optional

import astunparse
import tqdm
from frozendict import frozendict

from operators import evaluate_condition
2023-11-15 17:23:53 +00:00
# Directory containing this file; the other paths below are built from it.
ROOT_DIR: str = os.path.dirname(__file__)
# Directory searched (recursively) for benchmark sources to instrument.
IN_DIR: str = os.path.join(ROOT_DIR, 'benchmark')
# Directory that receives the instrumented copies of the benchmark files.
OUT_DIR: str = os.path.join(ROOT_DIR, 'instrumented')
# Appended to a function's name to form its instrumented name.
SUFFIX: str = "_instrumented"
2023-11-13 15:33:20 +00:00
class BranchTransformer(ast.NodeTransformer):
    """AST transformer that instruments the branch conditions of a module.

    Every function definition ``f`` is renamed to ``f + SUFFIX``. Calls by
    name to the function currently being visited are renamed the same way,
    so recursion keeps hitting the instrumented version. Every single
    (non-chained) comparison that is not an identity/membership test and is
    not inside an ``assert`` or ``return`` statement is replaced by
    ``evaluate_condition(branch_id, op_name, left, right)``.
    """

    # Instrumented function name -> (first branch id, last branch id + 1)
    # of the branch ids assigned while visiting that function.
    branches_range: dict[str, tuple[int, int]]
    # Number of branches instrumented so far; also the last id handed out.
    branch_num: int
    # Original name of the function currently being visited, or None.
    instrumented_name: Optional[str]
    # True while visiting inside an assert / return statement; comparisons
    # found there are left uninstrumented.
    in_assert: bool
    in_return: bool

    def __init__(self):
        self.branch_num = 0
        self.instrumented_name = None
        self.in_assert = False
        self.in_return = False
        self.branches_range = {}

    @staticmethod
    def to_instrumented_name(name: str) -> str:
        """Return the instrumented counterpart of function name ``name``."""
        return name + SUFFIX

    @staticmethod
    def to_original_name(name: str) -> str:
        """Strip SUFFIX from an instrumented function name (asserts it is present)."""
        assert name.endswith(SUFFIX)
        return name[:len(name) - len(SUFFIX)]

    def visit_Assert(self, ast_node):
        previous = self.in_assert
        self.in_assert = True
        self.generic_visit(ast_node)
        self.in_assert = previous
        return ast_node

    def visit_Return(self, ast_node):
        previous = self.in_return
        self.in_return = True
        self.generic_visit(ast_node)
        self.in_return = previous
        return ast_node

    def visit_FunctionDef(self, ast_node):
        # Remember the enclosing function so that a nested ``def`` does not
        # reset the name used for renaming recursive calls in the outer body.
        enclosing_name = self.instrumented_name
        self.instrumented_name = ast_node.name
        b_start = self.branch_num
        ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name)
        inner_node = self.generic_visit(ast_node)
        self.branches_range[ast_node.name] = (b_start + 1, self.branch_num + 1)
        self.instrumented_name = enclosing_name
        return inner_node

    def visit_Call(self, ast_node):
        # Visit arguments first: they may contain recursive calls or
        # comparisons that must be instrumented as well.
        self.generic_visit(ast_node)
        if isinstance(ast_node.func, ast.Name) and ast_node.func.id == self.instrumented_name:
            ast_node.func.id = BranchTransformer.to_instrumented_name(ast_node.func.id)
        return ast_node

    def visit_Compare(self, ast_node):
        # Instrument nested nodes (recursive calls, inner comparisons) first;
        # the original skipped them whenever the outer comparison was wrapped.
        self.generic_visit(ast_node)
        op = ast_node.ops[0]
        # ``ops`` holds operator *instances*, so isinstance is required
        # (``op in [ast.Is, ...]`` compared instances to classes: always False).
        # Chained comparisons (a < b < c) cannot be expressed as one
        # two-operand call without dropping operands, so they are left as is.
        if (len(ast_node.ops) != 1
                or isinstance(op, (ast.Is, ast.IsNot, ast.In, ast.NotIn))
                or self.in_assert
                or self.in_return):
            return ast_node

        self.branch_num += 1
        call = ast.Call(func=ast.Name(id=evaluate_condition.__name__, ctx=ast.Load()),
                        args=[ast.Constant(value=self.branch_num),
                              ast.Constant(value=op.__class__.__name__),
                              ast_node.left,
                              ast_node.comparators[0]],
                        keywords=[])
        return ast.copy_location(call, ast_node)
# Text of a parameter's type annotation as written in the source (e.g. "int").
ArgType = str
# (parameter name, annotation) pair; the annotation is None when absent.
Arg = tuple[str, Optional[ArgType]]
# Keyword arguments (name -> value) for one call of an instrumented function.
# ``typing.Any`` is the "any type" marker; the builtin ``any`` is a function.
Params = frozendict[str, Any]
# Instrumented function name -> its positional parameter list.
SignatureDict = dict[str, list[Arg]]

# Instrumented function name -> (first branch id, last branch id + 1).
n_of_branches: dict[str, tuple[int, int]] = {}
# Top-level instrumented function name -> list of (argument name, annotation).
functions: SignatureDict = {}
# Top-level instrumented function name -> path components (relative to
# ROOT_DIR, extension removed) of the source file that defines it.
module_of: dict[str, list[str]] = {}
def _annotation_name(annotation: Optional[ast.expr]) -> Optional[str]:
    """Return the source text of a parameter annotation, or None if there is none."""
    if annotation is None:
        return None
    if isinstance(annotation, ast.Name):
        return annotation.id
    # Annotations such as ``list[int]`` or ``typing.Any`` are not plain names
    # and have no ``.id``; use their source text instead of crashing.
    return ast.unparse(annotation)


def instrument(source_path: str, target_path: str, save_instrumented=True):
    """Instrument one source file and load its functions into this module.

    The file at ``source_path`` is parsed and transformed with
    ``BranchTransformer``. If ``save_instrumented`` is true, the result is
    written (unparsed) to ``target_path``. The result is then compiled and
    executed in this module's namespace so ``invoke`` can call it.

    Side effects on module-level registries:
      * ``n_of_branches`` gains the branch-id range of every instrumented
        function (including nested ones);
      * ``functions`` maps each top-level function to its
        (argument name, annotation) list;
      * ``module_of`` maps each top-level function to the path components of
        ``source_path`` relative to ROOT_DIR, without extension.
    """
    # Python source files are UTF-8 by default (PEP 3120).
    with open(source_path, "r", encoding="utf-8") as f:
        source = f.read()
    node = ast.parse(source)

    transformer = BranchTransformer()
    transformer.visit(node)
    n_of_branches.update(transformer.branches_range)

    node = ast.fix_missing_locations(node)  # Make sure the line numbers are ok before printing
    if save_instrumented:
        with open(target_path, "w", encoding="utf-8") as f:
            print(astunparse.unparse(node), file=f)

    # NOTE: executes the (instrumented) benchmark code in this module's
    # namespace; source_path must therefore be trusted.
    current_module = sys.modules[__name__]
    code = compile(node, filename="<ast>", mode="exec")
    exec(code, current_module.__dict__)

    # normpath uses the platform separator, so split on os.sep rather than a
    # hard-coded "/" (which broke module names on Windows).
    rel_path = os.path.normpath(os.path.relpath(source_path, ROOT_DIR))
    module_parts = os.path.splitext(rel_path)[0].split(os.sep)

    # Record the signatures of the top-level function definitions.
    assert isinstance(node, ast.Module)
    for f_def in node.body:
        if not isinstance(f_def, ast.FunctionDef):
            continue
        functions[f_def.name] = [(arg.arg, _annotation_name(arg.annotation)) for arg in f_def.args.args]
        module_of[f_def.name] = list(module_parts)
2023-11-15 17:23:53 +00:00
def invoke(f_name: str, f_args: Params) -> Any:
    """Call a previously loaded instrumented function by name.

    :param f_name: instrumented function name, as recorded in ``functions``.
    :param f_args: keyword arguments to pass; every argument name listed in
        ``functions[f_name]`` must be present.
    :raises ValueError: if ``f_name`` was not loaded, or a listed argument is missing.
    :return: whatever the called function returns.
    """
    if f_name not in functions:
        raise ValueError(f"Function '{f_name}' not loaded")
    for arg_name, _ in functions[f_name]:
        if arg_name not in f_args:
            raise ValueError(f"Required argument '{arg_name}' not provided")
    current_module = sys.modules[__name__]
    return getattr(current_module, f_name)(**f_args)
2023-11-13 15:33:20 +00:00
def find_py_files(search_dir: str):
    """Lazily yield the path of every ``.py`` file under ``search_dir``, recursively, in ``os.walk`` order."""
    yield from (
        os.path.join(directory, name)
        for directory, _subdirs, names in os.walk(search_dir)
        for name in names
        if name.endswith(".py")
    )
2023-12-09 16:56:04 +00:00
def load_benchmark(save_instrumented=True, files: list[str] = ()):
    """Instrument and load the benchmark files found under IN_DIR.

    :param save_instrumented: if true, also write each instrumented file to
        OUT_DIR (created if missing) under its original base name.
    :param files: names or paths selecting which benchmark files to load.
        Only the base name without extension is used (``.py`` is assumed),
        matched against base names of files under IN_DIR. Empty (the
        default) loads every ``.py`` file under IN_DIR.
    """
    to_load = {os.path.splitext(os.path.basename(file))[0] + ".py" for file in files}
    do_all = not to_load

    if save_instrumented:
        # instrument() opens files inside OUT_DIR for writing, which fails
        # with FileNotFoundError if the directory does not exist yet.
        os.makedirs(OUT_DIR, exist_ok=True)

    for file in tqdm.tqdm(find_py_files(IN_DIR), desc="Instrumenting"):
        filename = os.path.basename(file)
        if do_all or filename in to_load:
            instrument(file, os.path.join(OUT_DIR, filename), save_instrumented=save_instrumented)
2023-11-15 17:23:53 +00:00
def call_statement(f_name: str, f_args: Params) -> str:
    """Render a call to ``f_name`` with keyword arguments ``f_args`` as source text.

    Each argument becomes ``name=<repr of value>`` (so strings are quoted), in
    the mapping's iteration order.
    """
    rendered = ", ".join(f"{key}={value!r}" for key, value in f_args.items())
    return f"{f_name}({rendered})"
2023-11-13 15:33:20 +00:00
2023-12-09 19:52:07 +00:00
def get_benchmark() -> dict[str, list[str]]:
    """
    Returns a dictionary associating each loaded source file name (without extension) with the list of
    (non-instrumented) names of the functions defined in it
    """
    by_module: defaultdict[str, list[str]] = defaultdict(list)
    for instrumented in functions:
        module_name = module_of[instrumented][-1]
        by_module[module_name].append(BranchTransformer.to_original_name(instrumented))
    return dict(by_module)
2023-11-13 15:33:20 +00:00
if __name__ == '__main__':
    # Instrument every benchmark file and also save the instrumented copies.
    load_benchmark(save_instrumented=True)