220 lines
6.7 KiB
Python
220 lines
6.7 KiB
Python
import ast
import os.path
import random
import sys
from typing import Any, Optional

import astunparse
import tqdm
from frozendict import frozendict

from operators import compute_distances
|
|
|
|
# hyperparameters

# Directory containing this module; made absolute so every derived path is
# independent of the current working directory.
ROOT_DIR: str = os.path.dirname(os.path.abspath(__file__))
IN_DIR: str = os.path.join(ROOT_DIR, 'benchmark')  # sources to instrument
OUT_DIR: str = os.path.join(ROOT_DIR, 'instrumented')  # instrumented copies
SUFFIX: str = "_instrumented"  # appended to every instrumented function name

# Smallest distance observed so far per branch id, for the branch being
# taken (true) and not taken (false); maintained by update_maps.
distances_true: dict[int, int] = {}
distances_false: dict[int, int] = {}

# Archive of solutions
# NOTE(review): presumably keyed by branch id; the producers of these values
# are not visible here - confirm.
archive_true_branches: dict[int, Any] = {}
archive_false_branches: dict[int, Any] = {}
|
|
|
|
|
|
class BranchTransformer(ast.NodeTransformer):
    """AST transformer that instruments comparisons for branch-distance tracking.

    Each ``def`` is renamed by appending SUFFIX, and calls to the function's
    original name inside its own body are renamed to match.  Every
    single-operator comparison outside ``assert``/``return`` statements is
    replaced by ``evaluate_condition(branch_id, op_name, lhs, rhs)``, where
    ``op_name`` is the operator's AST class name (e.g. ``"Lt"``).
    """

    # Comparison operators that are never instrumented.  ``In`` is deliberately
    # absent: evaluate_condition computes a dedicated distance for it.
    UNINSTRUMENTED_OPS = (ast.Is, ast.IsNot, ast.NotIn)

    # instrumented function name -> half-open range [first, last + 1) of its branch ids
    branches_range: dict[str, tuple[int, int]]
    # last branch id handed out (ids start at 1)
    branch_num: int
    # original name of the function currently being visited, if any
    instrumented_name: Optional[str]
    # True while visiting inside an ``assert`` / ``return`` statement
    in_assert: bool
    in_return: bool

    def __init__(self):
        self.branch_num = 0
        self.instrumented_name = None
        self.in_assert = False
        self.in_return = False
        self.branches_range = {}

    @staticmethod
    def to_instrumented_name(name: str):
        """Return the instrumented counterpart of function ``name``."""
        return name + SUFFIX

    @staticmethod
    def to_original_name(name: str):
        """Strip SUFFIX from an instrumented function name (asserts it is present)."""
        assert name.endswith(SUFFIX)
        return name[:len(name) - len(SUFFIX)]

    def visit_Assert(self, ast_node):
        """Visit an assert statement without instrumenting its comparisons."""
        self.in_assert = True
        self.generic_visit(ast_node)
        self.in_assert = False
        return ast_node

    def visit_Return(self, ast_node):
        """Visit a return statement without instrumenting its comparisons."""
        self.in_return = True
        self.generic_visit(ast_node)
        self.in_return = False
        return ast_node

    def visit_FunctionDef(self, ast_node):
        """Rename the function, instrument its body and record its branch-id range."""
        # Saved so that, after a nested def, recursive calls in the rest of the
        # enclosing function are still redirected to the instrumented name.
        enclosing_name = self.instrumented_name
        self.instrumented_name = ast_node.name
        b_start = self.branch_num
        ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name)
        inner_node = self.generic_visit(ast_node)
        self.branches_range[ast_node.name] = (b_start + 1, self.branch_num + 1)
        self.instrumented_name = enclosing_name
        return inner_node

    def visit_Call(self, ast_node):
        """Redirect calls to the current function's original name to its instrumented name."""
        # Visit arguments too, so nested recursive calls (f(f(x))) are renamed
        # and comparisons passed as arguments are instrumented.
        self.generic_visit(ast_node)
        if isinstance(ast_node.func, ast.Name) and ast_node.func.id == self.instrumented_name:
            ast_node.func.id = BranchTransformer.to_instrumented_name(ast_node.func.id)
        return ast_node

    def visit_Compare(self, ast_node):
        """Replace a single-operator comparison with an ``evaluate_condition`` call."""
        skip = (
            # Chained comparisons (a < b < c) cannot be expressed with a single
            # lhs/rhs pair without dropping operands, so leave them as-is.
            len(ast_node.ops) != 1
            or isinstance(ast_node.ops[0], self.UNINSTRUMENTED_OPS)
            or self.in_assert
            or self.in_return
        )
        if skip:
            return self.generic_visit(ast_node)

        # Number this comparison before its operands so outer ids stay stable.
        self.branch_num += 1
        branch_id = self.branch_num
        # Instrument nested comparisons / rename recursive calls in the operands.
        self.generic_visit(ast_node)

        return ast.Call(
            func=ast.Name(id="evaluate_condition", ctx=ast.Load()),
            args=[
                ast.Constant(value=branch_id),
                ast.Constant(value=type(ast_node.ops[0]).__name__),
                ast_node.left,
                ast_node.comparators[0],
            ],
            keywords=[],
        )
|
|
|
|
|
|
def update_maps(condition_num, d_true, d_false):
    """Fold the distances observed for branch ``condition_num`` into the global maps.

    Each map keeps, per branch id, the smallest distance seen so far; the
    first observation for an id is stored unchanged.  The true-map is
    updated before the false-map.
    """
    for distance_map, observed in ((distances_true, d_true), (distances_false, d_false)):
        if condition_num in distance_map:
            distance_map[condition_num] = min(distance_map[condition_num], observed)
        else:
            distance_map[condition_num] = observed
|
|
|
|
|
|
def _as_number(value):
    """Map a one-character string to its code point; return any other value unchanged."""
    if isinstance(value, str) and len(value) == 1:
        return ord(value)
    return value


def _membership_distance(lhs, rhs):
    """Return the true-distance of ``lhs in rhs``: 0 exactly when the test holds.

    For a non-member, the distance is the smallest absolute difference between
    lhs and any element of rhs, with one-character strings compared by code
    point (``sys.maxsize`` for an empty rhs).  Operands that cannot be
    subtracted, or a zero difference caused by mixed types, yield 1.
    rhs is expected to be a re-iterable container (dict, list, set, str, ...).
    """
    # Decide membership with the real operator so the instrumented program
    # keeps exactly the original semantics.
    if lhs in rhs:
        return 0
    target = _as_number(lhs)
    try:
        graded = min((abs(target - _as_number(elem)) for elem in rhs), default=sys.maxsize)
    except TypeError:
        return 1
    return graded if graded else 1


def evaluate_condition(num, op, lhs, rhs):  # type: ignore
    """Evaluate an instrumented comparison and record its branch distances.

    Args:
        num: branch id assigned by BranchTransformer.
        op: AST class name of the comparison operator (e.g. ``"Lt"``, ``"In"``).
        lhs: evaluated left operand.
        rhs: evaluated right operand.

    Returns:
        True iff the computed true-distance is 0, which is meant to coincide
        with the outcome of the original comparison.
    """
    if op == "In":
        distance_true = _membership_distance(lhs, rhs)
        distance_false = 1 if distance_true == 0 else 0
    else:
        distance_true, distance_false = compute_distances(op, lhs, rhs)

    update_maps(num, distance_true, distance_false)

    # distance == 0 equivalent to actual test by construction
    return distance_true == 0
|
|
|
|
|
|
# Name of a parameter's annotated type.
ArgType = str
# (parameter name, annotated type name or None when the parameter is unannotated)
Arg = tuple[str, Optional[ArgType]]
# Keyword arguments for one call, as an immutable mapping.
Params = frozendict[str, Any]
# Function name -> ordered parameter list.
SignatureDict = dict[str, list[Arg]]

# Instrumented function name -> half-open range [first, last + 1) of its branch ids.
n_of_branches: dict[str, tuple[int, int]] = {}
# Instrumented function name -> parameters of every loaded top-level function.
functions: SignatureDict = {}
# Instrumented function name -> dotted path of the file it was loaded from.
module_of: dict[str, str] = {}
|
|
|
|
|
|
def _annotation_name(annotation: Optional[ast.expr]) -> Optional[str]:
    """Return a parameter annotation as source text, or None when absent."""
    if annotation is None:
        return None
    if isinstance(annotation, ast.Name):
        return annotation.id
    if isinstance(annotation, ast.Constant) and isinstance(annotation.value, str):
        return annotation.value  # string (forward-reference) annotation
    # Subscripts, attributes, unions, ...: keep their source spelling.
    return ast.unparse(annotation)


def _module_name(source_path: str) -> str:
    """Return the dotted module path of ``source_path`` relative to ROOT_DIR."""
    relative = os.path.normpath(os.path.relpath(source_path, ROOT_DIR))
    return os.path.splitext(relative)[0].replace(os.sep, ".")


def instrument(source_path: str, target_path: str, save_instrumented=True):
    """Instrument the Python file ``source_path`` and load its functions here.

    The file is parsed and rewritten by BranchTransformer.  When
    ``save_instrumented`` is true, the result is written to ``target_path``,
    creating its directory if needed.  The result is then executed in this
    module's namespace.  Each function's branch-id range is recorded in
    ``n_of_branches``.  For every top-level function, its (name, annotation)
    parameter list goes into ``functions`` and its dotted module path into
    ``module_of``; both are keyed by the instrumented function name.

    Raises:
        ValueError: if a top-level function name is already registered.  The
            check runs before anything is saved, executed or registered.
    """
    with open(source_path, "r", encoding="utf-8") as src:
        source = src.read()

    node = ast.parse(source, filename=source_path)
    transformer = BranchTransformer()
    transformer.visit(node)
    node = ast.fix_missing_locations(node)  # new Call nodes need line numbers

    # Figure out the top level function definitions (names are already instrumented)
    assert isinstance(node, ast.Module)
    top_level_f_ast: list[ast.FunctionDef] = [
        stmt for stmt in node.body if isinstance(stmt, ast.FunctionDef)
    ]

    # Reject duplicates before exec, so a previously loaded function is not
    # silently overwritten in this module's namespace.
    seen = set(functions)
    for func_def in top_level_f_ast:
        if func_def.name in seen:
            raise ValueError(f"Function '{func_def.name}' already loaded from another file")
        seen.add(func_def.name)

    n_of_branches.update(transformer.branches_range)

    if save_instrumented:
        target_dir = os.path.dirname(target_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        with open(target_path, "w", encoding="utf-8") as out:
            print(astunparse.unparse(node), file=out)

    code = compile(node, filename="<ast>", mode="exec")
    exec(code, sys.modules[__name__].__dict__)

    module_name = _module_name(source_path)
    for func_def in top_level_f_ast:
        functions[func_def.name] = [
            (arg.arg, _annotation_name(arg.annotation)) for arg in func_def.args.args
        ]
        module_of[func_def.name] = module_name
|
|
|
|
|
|
def invoke(f_name: str, f_args: Params) -> Any:
    """Call the loaded function ``f_name`` with keyword arguments ``f_args``.

    ``f_name`` is looked up in this module's namespace, i.e. it must be a name
    registered by ``instrument`` (the instrumented name).

    Raises:
        ValueError: if ``f_name`` is not registered, or one of its registered
            parameters is missing from ``f_args``.
    """
    if f_name not in functions:
        raise ValueError(f"Function '{f_name}' not loaded")

    for arg_name, _arg_type in functions[f_name]:
        if arg_name not in f_args:
            raise ValueError(f"Required argument '{arg_name}' not provided")

    return getattr(sys.modules[__name__], f_name)(**f_args)
|
|
|
|
|
|
def find_py_files(search_dir: str):
    """Lazily yield the path of every ``.py`` file below ``search_dir``, recursively."""
    for directory, _subdirs, filenames in os.walk(search_dir):
        yield from (os.path.join(directory, name) for name in filenames if name.endswith(".py"))
|
|
|
|
|
|
def load_benchmark(save_instrumented=True):
    """Instrument every ``.py`` file under IN_DIR, with a progress bar.

    Each file's instrumented copy targets ``OUT_DIR/<basename>`` and is only
    written when ``save_instrumented`` is true.
    """
    # NOTE(review): files in different sub-directories that share a basename
    # map to the same target path, so later ones overwrite earlier ones.
    progress = tqdm.tqdm(find_py_files(IN_DIR), desc="Instrumenting")
    for source_file in progress:
        target_file = os.path.join(OUT_DIR, os.path.basename(source_file))
        instrument(source_file, target_file, save_instrumented=save_instrumented)
|
|
|
|
|
|
def call_statement(f_name: str, f_args: Params) -> str:
    """Render a keyword-argument call to ``f_name`` as Python source text.

    Values are rendered with ``repr`` so that, e.g., strings come out quoted.
    """
    rendered_args = ", ".join(f"{name}={value!r}" for name, value in f_args.items())
    return f"{f_name}({rendered_args})"
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: instrument the whole benchmark and save the
    # instrumented sources.
    load_benchmark(save_instrumented=True)
|