This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
kse-02/instrument.py

317 lines
9.7 KiB
Python
Raw Normal View History

2023-11-13 15:33:20 +00:00
import ast
import os.path
import random
import sys
from typing import Any, Optional

import astunparse
import tqdm
from deap import algorithms, base, creator, tools
from frozendict import frozendict

from operators import compute_distances
2023-11-13 15:33:20 +00:00
# --- Genetic-search hyperparameters ---
NPOP: int = 300  # population size passed to toolbox.population
NGEN: int = 200  # generations per eaSimple run
INDMUPROB: float = 0.05  # presumably a per-gene mutation probability; not referenced in this file
MUPROB: float = 0.1  # eaSimple mutpb: probability that an offspring is mutated
CXPROB: float = 0.5  # eaSimple cxpb: probability that a pair of offspring is mated
TOURNSIZE: int = 3  # tournament size for selection
LOW: int = -1000  # presumably lower bound for random numeric inputs; not referenced in this file
UP: int = 1000  # presumably upper bound for random numeric inputs; not referenced in this file
REPS: int = 10  # independent search repetitions performed by generate()
MAX_STRING_LENGTH: int = 10  # longest string produced by random_string()
2023-11-15 17:23:53 +00:00
# All locations are derived from the directory that contains this file.
ROOT_DIR: str = os.path.dirname(__file__)
IN_DIR: str = os.path.join(ROOT_DIR, "benchmark")  # sources scanned by load_benchmark
OUT_DIR: str = os.path.join(ROOT_DIR, "instrumented")  # destination of instrumented copies
# Appended to the name of every function renamed by the instrumentation.
SUFFIX: str = "_instrumented"
2023-11-13 15:33:20 +00:00
# Smallest branch distance recorded per branch id for the true / false outcome
# (written by update_maps, reset by get_fitness_cgi before each evaluation).
distances_true: dict[int, int] = {}
distances_false: dict[int, int] = {}
# Branch ids whose outcomes contribute to the fitness.
branches: list[int] = list(range(1, 6))
# First input found that covers each branch outcome: branch id -> input string.
archive_true_branches: dict[int, str] = {}
archive_false_branches: dict[int, str] = {}
class BranchTransformer(ast.NodeTransformer):
    """AST pass that renames functions and instruments comparisons.

    * Every ``def name`` is renamed to ``name + SUFFIX``. Direct calls to the function
      currently being visited (recursion) are renamed the same way, including calls
      inside ``assert`` and ``return`` statements.
    * Every single-operator comparison ``lhs <op> rhs`` outside ``assert``/``return``
      is replaced by ``evaluate_condition(<branch id>, "<OpClassName>", lhs, rhs)``.
      Branch ids are numbered 1, 2, ... in visiting order; operands are visited first.
    """

    branch_num: int  # last branch id handed out (0 before any comparison is instrumented)
    instrumented_name: Optional[str]  # original name of the function being visited, else None

    # Comparison operators that are never instrumented. ``In`` is intentionally absent:
    # evaluate_condition computes a dedicated distance for membership tests.
    _SKIPPED_OPS = (ast.Is, ast.IsNot, ast.NotIn)

    def __init__(self):
        super().__init__()
        self.branch_num = 0
        self.instrumented_name = None

    @staticmethod
    def to_instrumented_name(name: str):
        """Return *name* with SUFFIX appended."""
        return name + SUFFIX

    @staticmethod
    def to_original_name(name: str):
        """Inverse of to_instrumented_name; *name* must end with SUFFIX."""
        assert name.endswith(SUFFIX)
        return name.removesuffix(SUFFIX)

    def _rename_call(self, call):
        """Point a direct call to the function being visited at its renamed version."""
        if isinstance(call.func, ast.Name) and call.func.id == self.instrumented_name:
            call.func.id = BranchTransformer.to_instrumented_name(call.func.id)

    def _rename_calls_only(self, ast_node):
        """Rename recursive calls anywhere below *ast_node* without instrumenting it."""
        for child in ast.walk(ast_node):
            if isinstance(child, ast.Call):
                self._rename_call(child)
        return ast_node

    def visit_Assert(self, ast_node):
        # Assert conditions are not instrumented, but recursive calls inside them must
        # still target the renamed function (the original name is never defined).
        return self._rename_calls_only(ast_node)

    def visit_Return(self, ast_node):
        # Same for return statements, e.g. ``return fib(n - 1) + fib(n - 2)``.
        return self._rename_calls_only(ast_node)

    def visit_FunctionDef(self, ast_node):
        # Remember the enclosing function so a nested def does not clobber it.
        enclosing_name = self.instrumented_name
        self.instrumented_name = ast_node.name
        ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name)
        try:
            return self.generic_visit(ast_node)
        finally:
            self.instrumented_name = enclosing_name

    def visit_Call(self, ast_node):
        self._rename_call(ast_node)
        # Descend so nested recursive calls are renamed and comparisons passed as
        # arguments are instrumented.
        return self.generic_visit(ast_node)

    def visit_Compare(self, ast_node):
        # Instrument/rename inside the operands first.
        self.generic_visit(ast_node)
        if len(ast_node.ops) != 1 or isinstance(ast_node.ops[0], self._SKIPPED_OPS):
            # A chained comparison (a < b < c) cannot be replaced by one
            # evaluate_condition call without dropping its later comparisons.
            return ast_node
        self.branch_num += 1
        return ast.Call(
            func=ast.Name(id="evaluate_condition", ctx=ast.Load()),
            args=[
                ast.Constant(self.branch_num),
                ast.Constant(type(ast_node.ops[0]).__name__),
                ast_node.left,
                ast_node.comparators[0],
            ],
            keywords=[],
        )
def update_maps(condition_num, d_true, d_false):
    """Fold one evaluation of branch *condition_num* into the global distance maps.

    Each map keeps the smallest distance observed for the branch since the maps were
    last reset, for the true outcome and the false outcome respectively.
    """
    for distance_map, new_distance in ((distances_true, d_true), (distances_false, d_false)):
        if condition_num in distance_map:
            distance_map[condition_num] = min(distance_map[condition_num], new_distance)
        else:
            distance_map[condition_num] = new_distance
def _in_distance(lhs, rhs):
    """Return how far *lhs* is from being an element of *rhs*, assuming it is not one.

    Single characters are compared by code point and numbers directly; the result is the
    smallest such difference over the elements of *rhs* (iterating a dict yields its keys).
    Pairs without a numeric difference (e.g. multi-character strings) count as 1, an empty
    *rhs* gives sys.maxsize, and the result is never 0.
    """
    def as_number(value):
        return ord(value) if isinstance(value, str) and len(value) == 1 else value

    target = as_number(lhs)
    closest = sys.maxsize
    for elem in rhs:
        try:
            closest = min(closest, abs(target - as_number(elem)))
        except TypeError:
            closest = min(closest, 1)
    # lhs is not a member, so a zero difference (e.g. 97 vs "a") must not read as "true".
    return closest if closest > 0 else 1


def evaluate_condition(num, op, lhs, rhs):  # type: ignore
    """Record branch distances for instrumented comparison *num* and return its outcome.

    Calls to this function replace ``lhs <op> rhs`` in code rewritten by BranchTransformer.

    Args:
        num: branch id assigned by the instrumentation.
        op: class name of the ast comparison operator, e.g. ``"Lt"`` or ``"In"``.
        lhs: evaluated left operand.
        rhs: evaluated right operand.

    Returns:
        True iff the recorded true-branch distance is 0.

    Note: for ``"In"``, *rhs* is traversed again after the membership test, so a one-shot
    iterator only contributes its remaining elements to the distance.
    """
    if op == "In":
        # The real membership outcome decides the branch; the distance only guides the
        # search towards making the condition true.
        if lhs in rhs:
            distance_true, distance_false = 0, 1
        else:
            distance_true, distance_false = _in_distance(lhs, rhs), 0
    else:
        distance_true, distance_false = compute_distances(op, lhs, rhs)
    update_maps(num, distance_true, distance_false)
    # A true-distance of 0 means the condition holds: by construction for "In", and
    # presumably guaranteed by compute_distances for the other operators.
    return distance_true == 0
def normalize(x):
    """Squash a distance into a bounded score: x / (1 + x), which lies in [0, 1) for x >= 0."""
    denominator = 1.0 + x
    return x / denominator
def get_fitness_cgi(individual):
    """Fitness of *individual*: the summed normalized distances of still-uncovered branch outcomes.

    Resets the global distance maps, then handles the true outcomes followed by the false
    outcomes of every branch id in ``branches``. The first time an outcome is seen with
    distance 0, ``individual[0]`` is archived as its covering input. Every outcome that is
    not archived adds normalize(distance) to the fitness.

    Returns:
        A one-element tuple holding the fitness (DEAP fitness values are tuples).
    """
    global distances_true, distances_false
    global branches, archive_true_branches, archive_false_branches

    candidate = individual[0]
    # Distances are collected per execution, so start from empty maps.
    distances_true = {}
    distances_false = {}

    # TODO: the function under test is not executed yet, so no distances are recorded
    # and the fitness is always 0.0. The intended call was:
    #     try:
    #         cgi_decode_instrumented(candidate)
    #     except BaseException:
    #         pass

    fitness = 0.0
    outcomes = ((distances_true, archive_true_branches),
                (distances_false, archive_false_branches))
    for distances, archive in outcomes:
        for branch in branches:
            if branch not in distances:
                continue
            if distances[branch] == 0 and branch not in archive:
                archive[branch] = candidate
            if branch not in archive:
                fitness += normalize(distances[branch])
    return (fitness,)
def random_string(max_length=None):
    """Return a random string of printable ASCII characters (code points 32-126).

    Args:
        max_length: inclusive upper bound on the length, which is drawn uniformly from
            [0, max_length]. Defaults to MAX_STRING_LENGTH.
    """
    if max_length is None:
        max_length = MAX_STRING_LENGTH
    length = random.randint(0, max_length)
    # join instead of repeated concatenation; the random draws are the same as before.
    return "".join(chr(random.randrange(32, 127)) for _ in range(length))
def crossover(individual1, individual2):
    """Single-point crossover of the strings held in two individuals, in place.

    Both strings must be longer than one character; otherwise the individuals are
    returned unchanged. The cut point is drawn from [1, len(first string)], and the
    parts after the cut are exchanged.
    """
    left, right = individual1[0], individual2[0]
    if min(len(left), len(right)) <= 1:
        return individual1, individual2
    cut = random.randint(1, len(left))
    individual1[0], individual2[0] = left[:cut] + right[cut:], right[:cut] + left[cut:]
    return individual1, individual2
def mutate(individual):
    """Replace each character of ``individual[0]`` with probability 1/len by a random printable one.

    The individual is modified in place and returned in a one-element tuple, as DEAP expects.
    """
    chars = list(individual[0])
    if chars:
        rate = 1.0 / len(chars)
        for index, _ in enumerate(chars):
            if random.random() < rate:
                chars[index] = chr(random.randrange(32, 127))
    individual[0] = "".join(chars)
    return (individual,)
def generate():
    """Run the genetic search REPS times and report the branch coverage of each run.

    Each repetition starts from empty archives and a fresh population of NPOP individuals,
    each holding one random string. The population is evolved with DEAP's eaSimple for
    NGEN generations. The number of covered branch outcomes is then printed together with
    both archives.

    Returns:
        list[int]: covered branch outcomes (true + false archive sizes), one per repetition.
    """
    global archive_true_branches, archive_false_branches

    # deap.creator.create registers classes on the creator module and emits a
    # RuntimeWarning (overwriting the class) when a name is created again, so create
    # them only on the first call.
    if not hasattr(creator, "Fitness"):
        creator.create("Fitness", base.Fitness, weights=(-1.0,))  # negative weight: minimise
    if not hasattr(creator, "Individual"):
        creator.create("Individual", list, fitness=creator.Fitness)

    toolbox = base.Toolbox()
    toolbox.register("attr_str", random_string)
    toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attr_str, n=1)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)
    toolbox.register("evaluate", get_fitness_cgi)
    toolbox.register("mate", crossover)
    toolbox.register("mutate", mutate)
    toolbox.register("select", tools.selTournament, tournsize=TOURNSIZE)

    coverage = []
    for _ in range(REPS):
        # Fresh archives per repetition; get_fitness_cgi fills the current globals.
        archive_true_branches = {}
        archive_false_branches = {}
        population = toolbox.population(n=NPOP)
        algorithms.eaSimple(population, toolbox, CXPROB, MUPROB, NGEN, verbose=False)
        cov = len(archive_true_branches) + len(archive_false_branches)
        print(cov, archive_true_branches, archive_false_branches)
        coverage.append(cov)
    return coverage
# Type name of a parameter as written in its annotation (e.g. "int", "str").
ArgType = str
# (parameter name, annotation type name or None when the parameter is unannotated)
Arg = tuple[str, Optional[ArgType]]
# Keyword arguments for a call: parameter name -> value.
Params = frozendict[str, Any]
# Function name -> its positional-or-keyword parameters, in declaration order.
SignatureDict = dict[str, list[Arg]]

# Signatures of every function registered by instrument(), keyed by its (renamed) name.
functions: SignatureDict = {}

# Function name -> dotted module path, relative to ROOT_DIR, of the file it came from.
module_of: dict[str, str] = {}
def _annotation_type_name(annotation):
    """Return the text of a parameter annotation, or None if the parameter has none.

    Plain names (``int``) give their identifier and string annotations (``"int"``) their
    content; anything else (``list[int]``, ``typing.Optional[str]``) is unparsed.
    """
    if annotation is None:
        return None
    if isinstance(annotation, ast.Name):
        return annotation.id
    if isinstance(annotation, ast.Constant) and isinstance(annotation.value, str):
        return annotation.value
    return ast.unparse(annotation)


def _collect_signatures(module_node) -> "SignatureDict":
    """Map each top-level function of *module_node* to its (name, type name) parameter list.

    Raises:
        ValueError: if a function is already registered in ``functions`` or is defined
            more than once in the module.
    """
    signatures = {}
    for f in module_node.body:
        if not isinstance(f, ast.FunctionDef):
            continue
        if f.name in functions or f.name in signatures:
            raise ValueError(f"Function '{f.name}' already loaded from another file")
        signatures[f.name] = [(arg.arg, _annotation_type_name(arg.annotation)) for arg in f.args.args]
    return signatures


def _module_name(source_path: str) -> str:
    """Dotted module path of *source_path* relative to ROOT_DIR (``benchmark/x.py`` -> ``benchmark.x``)."""
    relative = os.path.normpath(os.path.relpath(source_path, ROOT_DIR))
    return os.path.splitext(relative)[0].replace(os.sep, ".")


def instrument(source_path: str, target_path: str, save_instrumented=True):
    """Instrument a Python file, execute it in this module and register its top-level functions.

    Args:
        source_path: file to instrument.
        target_path: where the unparsed instrumented source is written.
        save_instrumented: whether to write *target_path* at all.

    Raises:
        ValueError: if one of the file's top-level functions is already registered
            (nothing is written or executed in that case).
    """
    # Reading bytes lets ast.parse honour a PEP 263 coding declaration.
    with open(source_path, "rb") as f:
        source = f.read()
    node = ast.parse(source)

    BranchTransformer().visit(node)
    node = ast.fix_missing_locations(node)  # new nodes need line numbers before unparse/compile
    assert isinstance(node, ast.Module)

    # Validate first: executing before the check would overwrite an already loaded
    # function of the same name in this module's namespace.
    signatures = _collect_signatures(node)

    if save_instrumented:
        target_dir = os.path.dirname(target_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        with open(target_path, "w", encoding="utf-8") as f:
            print(astunparse.unparse(node), file=f)

    # Execute inside this module so the instrumented code resolves evaluate_condition
    # and invoke() can look the functions up with getattr.
    current_module = sys.modules[__name__]
    code = compile(node, filename="<ast>", mode="exec")
    exec(code, current_module.__dict__)

    module_name = _module_name(source_path)
    for f_name, arg_types in signatures.items():
        functions[f_name] = arg_types
        module_of[f_name] = module_name
2023-11-15 17:23:53 +00:00
def invoke(f_name: str, f_args: Params) -> Any:
    """Call the loaded function *f_name* with keyword arguments *f_args* and return its result.

    *f_name* is the name under which instrument() registered the function in ``functions``.

    Raises:
        ValueError: if *f_name* is not loaded, or a parameter recorded for it is missing
            from *f_args*.
    """
    if f_name not in functions:
        raise ValueError(f"Function '{f_name}' not loaded")
    for arg_name, _arg_type in functions[f_name]:
        if arg_name not in f_args:
            raise ValueError(f"Required argument '{arg_name}' not provided")
    function = getattr(sys.modules[__name__], f_name)
    return function(**f_args)
2023-11-13 15:33:20 +00:00
def find_py_files(search_dir: str):
    """Yield the path of every ``.py`` file below *search_dir*, walking it recursively."""
    for directory, _subdirs, filenames in os.walk(search_dir):
        yield from (os.path.join(directory, name) for name in filenames if name.endswith(".py"))
def load_benchmark(save_instrumented=True):
    """Instrument and load every ``.py`` file found under IN_DIR.

    When *save_instrumented* is true, each instrumented copy is written to OUT_DIR under
    the source file's base name (subdirectories are not recreated).
    """
    sources = find_py_files(IN_DIR)
    for source_file in tqdm.tqdm(sources, desc="Instrumenting"):
        target_file = os.path.join(OUT_DIR, os.path.basename(source_file))
        instrument(source_file, target_file, save_instrumented=save_instrumented)
2023-11-15 17:23:53 +00:00
def call_statement(f_name: str, f_args: Params) -> str:
    """Render the Python call expression ``f_name(k1=v1, k2=v2, ...)`` for *f_args*.

    String values are rendered with repr(), so quotes, backslashes and control characters
    are escaped and the result stays valid Python (plain strings still appear as 'abc').
    Other values use their str() form.
    """
    rendered = (f"{k}={v!r}" if isinstance(v, str) else f"{k}={v}" for k, v in f_args.items())
    return f"{f_name}({', '.join(rendered)})"
2023-11-13 15:33:20 +00:00
if __name__ == "__main__":
    # Run as a script: instrument every benchmark file and save the instrumented sources.
    load_benchmark(save_instrumented=True)