This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
kse-02/instrument.py
2023-11-19 14:52:52 +01:00

322 lines
9.8 KiB
Python

import ast
import os.path
import random
import sys
from typing import Any, Optional

import astunparse
import tqdm
from deap import algorithms, base, creator, tools
from frozendict import frozendict

from operators import compute_distances
# ---------------------------------------------------------------------------
# Genetic-algorithm hyperparameters
# ---------------------------------------------------------------------------
NPOP = 300              # individuals per population
NGEN = 200              # generations per evolutionary run
INDMUPROB = 0.05        # not referenced by the code shown here; presumably a per-gene mutation rate -- verify
MUPROB = 0.1            # mutation probability handed to eaSimple (mutpb)
CXPROB = 0.5            # crossover probability handed to eaSimple (cxpb)
TOURNSIZE = 3           # tournament size used for selection
LOW, UP = -1000, 1000   # not referenced by the code shown here; presumably numeric input bounds -- verify
REPS = 10               # independent repetitions of the whole search
MAX_STRING_LENGTH = 10  # longest string random_string() may produce

# Filesystem layout: benchmark sources live in IN_DIR, instrumented copies go to OUT_DIR.
ROOT_DIR: str = os.path.dirname(__file__)
IN_DIR: str = os.path.join(ROOT_DIR, "benchmark")
OUT_DIR: str = os.path.join(ROOT_DIR, "instrumented")
SUFFIX: str = "_instrumented"  # appended to the names of instrumented functions

# Search state shared by the instrumented code and the fitness function.
distances_true: dict[int, int] = {}   # condition number -> smallest distance to "true" seen
distances_false: dict[int, int] = {}  # condition number -> smallest distance to "false" seen
branches: list[int] = list(range(1, 6))  # condition numbers scored by the fitness function
archive_true_branches: dict[int, str] = {}   # condition number -> first input that made it true
archive_false_branches: dict[int, str] = {}  # condition number -> first input that made it false
class BranchTransformer(ast.NodeTransformer):
    """AST transformer that instruments comparisons for branch-distance search.

    * Every function definition that is not nested inside another function is renamed
      to ``name + SUFFIX``. Calls to that function by its original name from inside its
      own body (recursion), including from nested functions and from call arguments,
      are redirected to the new name.
    * Every single-operator comparison ``lhs OP rhs`` outside ``assert`` and ``return``
      statements is replaced by ``evaluate_condition(n, "OP", lhs, rhs)``. Here ``n``
      numbers the instrumented comparisons 1, 2, ... in visiting order, and ``"OP"`` is
      the ``ast.cmpop`` class name (e.g. ``"Lt"``).
    * Identity/membership tests (``is``, ``is not``, ``in``, ``not in``) and chained
      comparisons (``a < b < c``) are left as they are. Their sub-expressions are still
      visited.
    """

    branch_num: int  # id of the most recently instrumented comparison (0 = none yet)
    instrumented_name: Optional[str]  # original name of the function currently being renamed
    in_assert: bool  # True while visiting an ``assert`` statement
    in_return: bool  # True while visiting a ``return`` statement

    # Operators that are never instrumented. ``Compare.ops`` holds operator *instances*,
    # so this must be checked with isinstance (``op in [ast.Is, ...]`` never matches).
    _SKIPPED_OPS = (ast.Is, ast.IsNot, ast.In, ast.NotIn)

    def __init__(self):
        self.branch_num = 0
        self.instrumented_name = None
        self.in_assert = False
        self.in_return = False

    @staticmethod
    def to_instrumented_name(name: str):
        """Return the instrumented name for the original function *name*."""
        return name + SUFFIX

    @staticmethod
    def to_original_name(name: str):
        """Strip SUFFIX from an instrumented function name (asserts that it is present)."""
        assert name.endswith(SUFFIX)
        return name[:len(name) - len(SUFFIX)]

    def visit_Assert(self, ast_node):
        # Comparisons inside asserts are not treated as branches.
        self.in_assert = True
        self.generic_visit(ast_node)
        self.in_assert = False
        return ast_node

    def visit_Return(self, ast_node):
        # Comparisons inside returned expressions are not treated as branches.
        self.in_return = True
        self.generic_visit(ast_node)
        self.in_return = False
        return ast_node

    def visit_FunctionDef(self, ast_node):
        if self.instrumented_name is not None:
            # Nested function: keep its own name, so local calls to it keep working.
            # Keep redirecting recursive calls to the enclosing renamed function.
            return self.generic_visit(ast_node)
        self.instrumented_name = ast_node.name
        ast_node.name = BranchTransformer.to_instrumented_name(ast_node.name)
        try:
            return self.generic_visit(ast_node)
        finally:
            self.instrumented_name = None

    def visit_Call(self, ast_node):
        if isinstance(ast_node.func, ast.Name) and ast_node.func.id == self.instrumented_name:
            ast_node.func.id = BranchTransformer.to_instrumented_name(ast_node.func.id)
        # Also visit the callee expression and the arguments. They may contain
        # comparisons or further recursive calls, e.g. ``f(f(x))``.
        return self.generic_visit(ast_node)

    def visit_Compare(self, ast_node):
        if (len(ast_node.ops) != 1
                or isinstance(ast_node.ops[0], self._SKIPPED_OPS)
                or self.in_assert
                or self.in_return):
            # Not instrumented. Chained comparisons are skipped because rewriting only the
            # first operator would silently drop the rest of the chain.
            return self.generic_visit(ast_node)
        # Allocate the id before visiting the operands, so this comparison is numbered
        # before any comparison nested inside it (pre-order).
        self.branch_num += 1
        branch_id = self.branch_num
        # Transform the operands in place (renames recursive calls, instruments nested
        # comparisons) before they are moved into the evaluate_condition call.
        self.generic_visit(ast_node)
        return ast.Call(func=ast.Name("evaluate_condition", ast.Load()),
                        args=[ast.Constant(branch_id),
                              ast.Constant(ast_node.ops[0].__class__.__name__),
                              ast_node.left,
                              ast_node.comparators[0]],
                        keywords=[])
def update_maps(condition_num, d_true, d_false):
    """Fold one observation of condition *condition_num* into the global distance maps.

    Each of ``distances_true`` / ``distances_false`` keeps, per condition number, the
    smallest distance reported so far. The first report for a condition is stored as-is.
    """
    for table, observed in ((distances_true, d_true), (distances_false, d_false)):
        if condition_num not in table:
            table[condition_num] = observed
        else:
            table[condition_num] = min(table[condition_num], observed)
def evaluate_condition(num, op, lhs, rhs):  # type: ignore
    """Record branch distances for comparison *num* and return the comparison's outcome.

    Instrumented code calls this in place of ``lhs <op> rhs``, where *op* is the name of
    the ``ast.cmpop`` class (e.g. ``"Lt"``). The distances to making the comparison true
    and false are merged into the global maps via update_maps().

    For ``"In"``, the true distance is the smallest absolute difference between *lhs* and
    any element of *rhs*. Strings are compared by code point, so they must be single
    characters. The false distance is 1 when that true distance is 0, else 0. An empty
    *rhs* gives a true distance of ``sys.maxsize``. All other operators are delegated to
    ``compute_distances``.

    Returns ``True`` iff the true distance is 0.
    """
    if op == "In":
        def as_number(value):
            # Characters are compared by code point so string membership has a gradient.
            return ord(value) if isinstance(value, str) else value

        target = as_number(lhs)
        # Iterate rhs itself: dicts yield their keys, and sets, lists and strings work too.
        # The previous ``rhs.keys()`` only supported dicts.
        minimum = min((abs(target - as_number(elem)) for elem in rhs), default=sys.maxsize)
        distance_true, distance_false = minimum, 1 if minimum == 0 else 0
    else:
        distance_true, distance_false = compute_distances(op, lhs, rhs)
    update_maps(num, distance_true, distance_false)
    # distance == 0 equivalent to actual test by construction
    return distance_true == 0
def normalize(x):
    """Squash a finite non-negative distance into [0, 1): 0 maps to 0, large values approach 1."""
    return x / (x + 1.0)
def get_fitness_cgi(individual):
    """Score *individual* (a list whose first element is the candidate input).

    Lower is better: generate() registers the fitness with weight -1.0. The function
    clears the global distance maps. Then, for every condition number in ``branches``
    with a recorded distance, it first archives the candidate when that outcome was
    reached (distance 0) and is not archived yet. It then adds the normalized distance
    of each outcome that is still not archived. True outcomes are processed before
    false ones.

    Returns a 1-tuple, the form DEAP uses for fitness values.
    """
    global distances_true, distances_false
    global branches, archive_true_branches, archive_false_branches
    candidate = individual[0]
    # Reset any distance values from previous executions.
    distances_true, distances_false = {}, {}
    # TODO: fix this -- the function under test is not run, so both distance maps stay
    # empty at this point and the returned fitness is always 0.0. Original (disabled) call:
    #     try:
    #         cgi_decode_instrumented(candidate)
    #     except BaseException:
    #         pass
    score = 0.0
    for observed, archive in ((distances_true, archive_true_branches),
                              (distances_false, archive_false_branches)):
        for branch in branches:
            if branch not in observed:
                continue
            distance = observed[branch]
            if distance == 0 and branch not in archive:
                archive[branch] = candidate
            if branch not in archive:
                score += normalize(distance)
    return (score,)
def random_string():
    """Return a random string of 0..MAX_STRING_LENGTH printable ASCII characters (codes 32-126)."""
    size = random.randint(0, MAX_STRING_LENGTH)
    return "".join(chr(random.randrange(32, 127)) for _ in range(size))
def crossover(individual1, individual2):
    """Single-point crossover of the strings held by two individuals, in place.

    When both strings are longer than one character, a cut point is drawn uniformly from
    1..len(first string) and the tails after it are exchanged. Otherwise both individuals
    are left untouched. Returns the two individuals, as DEAP's mate operator expects.
    """
    left, right = individual1[0], individual2[0]
    if min(len(left), len(right)) <= 1:
        return individual1, individual2
    cut = random.randint(1, len(left))
    individual1[0] = left[:cut] + right[cut:]
    individual2[0] = right[:cut] + left[cut:]
    return individual1, individual2
def mutate(individual):
    """Point-mutate the string held by *individual*, in place.

    Each character is independently replaced, with probability 1/len, by a random
    printable ASCII character (codes 32-126). An empty string is left as it is.
    Returns a 1-tuple containing the individual, as DEAP mutation operators do.
    """
    genome = individual[0][:]
    size = len(genome)
    if size:
        rate = 1.0 / size
        for idx in range(size):
            if random.random() < rate:
                replacement = chr(random.randrange(32, 127))
                genome = genome[:idx] + replacement + genome[idx + 1:]
    individual[0] = genome
    return (individual,)
def generate():
    """Run the genetic search REPS times and report branch coverage for each run.

    Each run starts from empty global archives and a fresh population of NPOP random
    individuals. It then evolves that population for NGEN generations with DEAP's
    ``eaSimple``, using crossover probability CXPROB, mutation probability MUPROB and
    tournament selection of size TOURNSIZE; fitness is minimized (weight -1.0). Finally
    it prints and records how many true and false branch outcomes were archived.

    Returns:
        list[int]: the coverage count of each run, in run order. It was previously
        computed but discarded; returning it is backward compatible.
    """
    global archive_true_branches, archive_false_branches
    # creator.create registers classes on the deap.creator module and warns when a name
    # is created again, so only create them on the first call.
    if not hasattr(creator, "Fitness"):
        creator.create("Fitness", base.Fitness, weights=(-1.0,))
    if not hasattr(creator, "Individual"):
        creator.create("Individual", list, fitness=creator.Fitness)

    toolbox = base.Toolbox()
    toolbox.register("attr_str", random_string)
    # An individual is a one-element list holding a single random string.
    toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attr_str, n=1)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)
    toolbox.register("evaluate", get_fitness_cgi)
    toolbox.register("mate", crossover)
    toolbox.register("mutate", mutate)
    toolbox.register("select", tools.selTournament, tournsize=TOURNSIZE)

    coverage = []
    for _ in range(REPS):
        # Each repetition is an independent search that starts from empty archives.
        archive_true_branches = {}
        archive_false_branches = {}
        population = toolbox.population(n=NPOP)
        algorithms.eaSimple(population, toolbox, CXPROB, MUPROB, NGEN, verbose=False)
        cov = len(archive_true_branches) + len(archive_false_branches)
        print(cov, archive_true_branches, archive_false_branches)
        coverage.append(cov)
    return coverage
# Type aliases for the signatures of loaded benchmark functions.
ArgType = str  # a parameter annotation's name, e.g. "int"
# (parameter name, annotation) -- the annotation is None for unannotated parameters.
Arg = tuple[str, Optional[ArgType]]
# Keyword arguments for a call: parameter name -> value. ``typing.Any`` is the intended
# type here; the previous builtin ``any`` is a function, not a type.
Params = frozendict[str, Any]
SignatureDict = dict[str, list[Arg]]  # function name -> its parameters, in order
# Registries filled by instrument(), keyed by the instrumented function name:
functions: SignatureDict = {}  # -> parameter list
module_of: dict[str, str] = {}  # -> dotted path (relative to ROOT_DIR) of the defining file
def _annotation_name(annotation: Optional[ast.expr]) -> Optional[str]:
    """Return a parameter annotation as text (a bare name's id), or None if it has none."""
    if annotation is None:
        return None
    if isinstance(annotation, ast.Name):
        return annotation.id
    # Subscripted, dotted or string annotations (e.g. ``list[int]``) have no ``.id``.
    return ast.unparse(annotation)


def _module_name_of(source_path: str) -> str:
    """Return the dotted module path of *source_path*, relative to ROOT_DIR."""
    relative = os.path.normpath(os.path.relpath(source_path, ROOT_DIR))
    # Strip only the final extension (``replace(".py", "")`` also hit directory names)
    # and split on the platform separator so Windows paths are converted too.
    stem, _ext = os.path.splitext(relative)
    return stem.replace(os.sep, ".")


def instrument(source_path: str, target_path: str, save_instrumented=True):
    """Instrument the benchmark file *source_path* and load its functions into this module.

    The file is parsed and rewritten with BranchTransformer. When *save_instrumented* is
    true, the instrumented source is written to *target_path*, and its directory is
    created if needed. The result is then exec'd in this module's namespace. Finally,
    every top-level function is registered in ``functions`` (parameter names and
    annotations) and ``module_of`` (dotted source path). Both are keyed by the
    instrumented name, i.e. with SUFFIX appended.

    Raises:
        ValueError: if a top-level function name was already loaded, or occurs twice in
            this file. The check happens before anything is written, executed or
            registered, so a clash leaves previously loaded functions intact.
    """
    with open(source_path, "r", encoding="utf-8") as src:
        source = src.read()
    node = ast.parse(source)
    BranchTransformer().visit(node)
    node = ast.fix_missing_locations(node)  # Make sure the line numbers are ok before printing
    assert isinstance(node, ast.Module)

    # Top-level function definitions; BranchTransformer has already renamed them.
    top_level_f_ast: list[ast.FunctionDef] = [f for f in node.body if isinstance(f, ast.FunctionDef)]

    # Detect clashes up front: previously exec() ran first and had already overwritten
    # the earlier definition by the time the error was raised.
    seen: set[str] = set()
    for func_def in top_level_f_ast:
        if func_def.name in functions or func_def.name in seen:
            raise ValueError(f"Function '{func_def.name}' already loaded from another file")
        seen.add(func_def.name)

    if save_instrumented:
        target_dir = os.path.dirname(target_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        with open(target_path, "w", encoding="utf-8") as out:
            print(astunparse.unparse(node), file=out)

    # NOTE: exec runs the benchmark code with this process's privileges -- trusted files only.
    current_module = sys.modules[__name__]
    code = compile(node, filename="<ast>", mode="exec")
    exec(code, current_module.__dict__)

    module_name = _module_name_of(source_path)
    for func_def in top_level_f_ast:
        functions[func_def.name] = [(arg.arg, _annotation_name(arg.annotation))
                                    for arg in func_def.args.args]
        module_of[func_def.name] = module_name
def invoke(f_name: str, f_args: Params) -> Any:
    """Call the loaded function *f_name* with keyword arguments *f_args* and return its result.

    *f_name* must be a key of ``functions``. The function itself is looked up as an
    attribute of this module. The return annotation is ``typing.Any``; the builtin
    ``any`` used previously is a function, not a type.

    Raises:
        ValueError: if *f_name* is not registered, or one of its recorded parameters is
            missing from *f_args*.
    """
    if f_name not in functions:
        raise ValueError(f"Function '{f_name}' not loaded")
    for arg_name, _arg_type in functions[f_name]:
        if arg_name not in f_args:
            raise ValueError(f"Required argument '{arg_name}' not provided")
    return getattr(sys.modules[__name__], f_name)(**f_args)
def find_py_files(search_dir: str):
    """Yield the path of every ``.py`` file under *search_dir*, recursively (os.walk order)."""
    for root, _subdirs, names in os.walk(search_dir):
        yield from (os.path.join(root, name) for name in names if name.endswith(".py"))
def load_benchmark(save_instrumented=True):
    """Instrument every ``.py`` file under IN_DIR, showing a tqdm progress bar.

    Each file's instrumented copy targets OUT_DIR under the file's base name. Whether it
    is actually written is controlled by *save_instrumented*.
    NOTE(review): files in different subdirectories that share a base name map to the
    same target path.
    """
    progress = tqdm.tqdm(find_py_files(IN_DIR), desc="Instrumenting")
    for source_file in progress:
        destination = os.path.join(OUT_DIR, os.path.basename(source_file))
        instrument(source_file, destination, save_instrumented=save_instrumented)
def call_statement(f_name: str, f_args: Params) -> str:
    """Render a call of *f_name* with keyword arguments *f_args* as Python source text.

    String values are rendered with ``repr``, so quotes and backslashes inside them
    (random_string() can produce both ``'`` and ``\\``) still form a valid literal with
    the same value. Other values are rendered with ``str``.

    Example: ``call_statement("f", {"x": 1, "s": "a"})`` returns ``"f(x=1, s='a')"``.
    """
    rendered = [
        f"{name}={value!r}" if isinstance(value, str) else f"{name}={value}"
        for name, value in f_args.items()
    ]
    return f"{f_name}({', '.join(rendered)})"
if __name__ == "__main__":
    # Script entry point: instrument the whole benchmark and save the instrumented copies.
    load_benchmark(save_instrumented=True)