247 lines
7.5 KiB
Python
247 lines
7.5 KiB
Python
|
from typing import Optional
|
||
|
import os.path
|
||
|
|
||
|
import ast
|
||
|
import astunparse
|
||
|
import sys
|
||
|
import random
|
||
|
from deap import creator, base, tools, algorithms
|
||
|
from instrumentor import compute_distances
|
||
|
|
||
|
# hyperparameters
NPOP = 300  # population size requested from toolbox.population() in generate()
NGEN = 200  # number of generations passed to algorithms.eaSimple() in generate()
INDMUPROB = 0.05  # not referenced anywhere else in the visible part of this file
MUPROB = 0.1  # mutation probability passed to algorithms.eaSimple()
CXPROB = 0.5  # crossover probability passed to algorithms.eaSimple()
TOURNSIZE = 3  # tournament size for tools.selTournament
LOW = -1000  # not referenced anywhere else in the visible part of this file
UP = 1000  # not referenced anywhere else in the visible part of this file
REPS = 10  # number of independent GA runs performed by generate()
MAX_STRING_LENGTH = 10  # inclusive upper bound on the length drawn in random_string()

# Directories are resolved relative to this file: main() reads the sources
# found under IN_DIR and writes the instrumented copies into OUT_DIR.
IN_DIR: str = os.path.join(os.path.dirname(__file__), 'benchmark')
OUT_DIR: str = os.path.join(os.path.dirname(__file__), 'instrumented')

# Smallest distance recorded per condition number during the current
# execution; filled by update_maps() and reset by get_fitness_cgi().
distances_true: dict[int, int] = {}
distances_false: dict[int, int] = {}
# Condition numbers that get_fitness_cgi() sums distances over.
branches: list[int] = [1, 2, 3, 4, 5]
# Per branch, the first input for which get_fitness_cgi() saw distance 0
# on the true / false side; reset at the start of every run in generate().
archive_true_branches: dict[int, str] = {}
archive_false_branches: dict[int, str] = {}
|
||
|
|
||
|
|
||
|
def cgi_decode_instrumented(s: str) -> str:
    """Placeholder for the function that get_fitness_cgi() executes.

    This stub only returns an empty string so that the name exists for
    static checkers. instrument() executes compiled benchmark code inside
    this module's namespace, so a definition with the same name in that
    code would replace this stub at runtime.
    NOTE(review): confirm that the instrumented code really defines this
    name; the visible transformer does not rename function definitions.
    """
    return ""  # make mypy happy
|
||
|
|
||
|
|
||
|
class BranchTransformer(ast.NodeTransformer):
    """AST transformer that routes comparisons through ``evaluate_condition``.

    Every simple comparison ``lhs <op> rhs`` that is reached by the traversal
    is replaced by ``evaluate_condition(<n>, "<OpName>", lhs, rhs)`` where
    ``<n>`` is a running counter (``branch_num``) starting at 1. Calls from a
    function to itself are renamed to ``<name>_instrumented``.

    Comparisons inside ``assert`` and ``return`` statements, and anything
    inside the argument list of a call, are not visited and therefore stay
    untouched.
    """

    # Comparisons with these operators are left as they are. ``In`` is
    # deliberately NOT listed: evaluate_condition() has a dedicated "In"
    # branch, so membership tests keep being instrumented.
    _SKIPPED_OPS = (ast.Is, ast.IsNot, ast.NotIn)

    branch_num: int  # number of comparisons instrumented so far
    instrumented_name: Optional[str]  # name of the function being visited

    def __init__(self):
        super().__init__()
        self.branch_num = 0
        self.instrumented_name = None

    @staticmethod
    def to_instrumented_name(name: str):
        """Return ``name`` with the ``_instrumented`` suffix appended."""
        return name + "_instrumented"

    def visit_Assert(self, ast_node):
        # Disable recursion in asserts, i.e. do not instrument assert conditions
        return ast_node

    def visit_Return(self, ast_node):
        # Same thing for return statements
        return ast_node

    def visit_FunctionDef(self, ast_node):
        """Visit a function body while remembering the function's name.

        The previously active name is restored afterwards, so a nested
        function definition no longer disables self-call renaming for the
        remainder of the enclosing function.
        """
        # NOTE(review): the definition itself keeps its original name while
        # visit_Call renames self-calls to "<name>_instrumented"; confirm
        # whether the definition is meant to be renamed as well.
        enclosing_name = self.instrumented_name
        self.instrumented_name = ast_node.name
        try:
            return self.generic_visit(ast_node)
        finally:
            self.instrumented_name = enclosing_name

    def visit_Call(self, ast_node):
        """Rename a direct call to the function currently being visited.

        The call's arguments are not traversed, so comparisons nested in
        them are not instrumented.
        """
        if isinstance(ast_node.func, ast.Name) and ast_node.func.id == self.instrumented_name:
            ast_node.func.id = BranchTransformer.to_instrumented_name(ast_node.func.id)
        return ast_node

    def visit_Compare(self, ast_node):
        """Replace a simple comparison by a call to ``evaluate_condition``.

        Left untouched are:
        * chained comparisons such as ``a < b < c`` -- only the first
          operator/comparator pair could be forwarded, which would silently
          drop the rest of the condition;
        * comparisons whose operator is listed in ``_SKIPPED_OPS``.
        """
        if len(ast_node.ops) != 1:
            return ast_node

        operator_node = ast_node.ops[0]
        # ops[] holds operator *instances*, so the check must be isinstance();
        # testing membership in a list of classes never matches.
        if isinstance(operator_node, self._SKIPPED_OPS):
            return ast_node

        self.branch_num += 1
        return ast.Call(
            func=ast.Name("evaluate_condition", ast.Load()),
            args=[
                ast.Constant(self.branch_num),
                ast.Constant(type(operator_node).__name__),
                ast_node.left,
                ast_node.comparators[0],
            ],
            keywords=[],
        )
|
||
|
|
||
|
|
||
|
def update_maps(condition_num, d_true, d_false):
    """Record the distances observed for one evaluation of a condition.

    For each of the two module-level tables (``distances_true`` and
    ``distances_false``) the entry for ``condition_num`` becomes the smaller
    of the stored value and the new one; a missing entry is simply created.
    """
    observations = ((distances_true, d_true), (distances_false, d_false))
    for table, observed in observations:
        table[condition_num] = (
            min(table[condition_num], observed)
            if condition_num in table
            else observed
        )
|
||
|
|
||
|
|
||
|
def evaluate_condition(num, op, lhs, rhs):  # type: ignore
    """Evaluate an instrumented comparison and record its branch distances.

    Parameters:
        num: condition number assigned by the instrumentation.
        op: name of the comparison operator class, e.g. ``"Lt"`` or ``"In"``.
        lhs, rhs: the two operands of the original comparison.

    For ``"In"`` the true-distance is the smallest absolute difference
    between ``lhs`` and any element of ``rhs`` (single-character strings are
    compared by code point); the false-distance is 1 when a match exists and
    0 otherwise. Any iterable container is accepted for ``rhs``; iterating a
    dict yields its keys. Every other operator is delegated to
    ``compute_distances``.

    Returns:
        True when the true-distance is 0, False otherwise.
    """
    if op == "In":
        target = ord(lhs) if isinstance(lhs, str) else lhs

        # An empty container leaves the distance at sys.maxsize.
        minimum = sys.maxsize
        for elem in rhs:
            candidate = ord(elem) if isinstance(elem, str) else elem
            minimum = min(minimum, abs(target - candidate))

        distance_true = minimum
        distance_false = 1 if minimum == 0 else 0
    else:
        distance_true, distance_false = compute_distances(op, lhs, rhs)

    update_maps(num, distance_true, distance_false)

    # distance == 0 equivalent to actual test by construction
    return distance_true == 0
|
||
|
|
||
|
|
||
|
def normalize(x):
    """Return ``x / (1.0 + x)``, which maps 0 to 0.0 and grows towards 1."""
    denominator = 1.0 + x
    return x / denominator
|
||
|
|
||
|
|
||
|
def _add_uncovered_distances(fitness, distances, archive, test_input):
    """Add the normalized distances of not-yet-archived branches to ``fitness``.

    A branch observed with distance 0 that is not yet in ``archive`` is
    archived together with ``test_input`` and then contributes nothing.
    """
    for branch in branches:
        if branch not in distances:
            continue
        if distances[branch] == 0 and branch not in archive:
            archive[branch] = test_input
        if branch not in archive:
            fitness += normalize(distances[branch])
    return fitness


def get_fitness_cgi(individual):
    """Compute the fitness of ``individual`` (lower is better).

    The first element of ``individual`` is passed to
    ``cgi_decode_instrumented``. The fitness is the sum of the normalized
    true- and false-distances recorded during that call for every branch in
    ``branches`` that is not archived yet; inputs that reach distance 0 on a
    branch side are stored in the matching archive.

    Returns:
        A one-element tuple ``(fitness,)``.
    """
    x = individual[0]
    # Reset any distance values from previous executions
    global distances_true, distances_false
    distances_true = {}
    distances_false = {}

    # Run the function under test
    try:
        cgi_decode_instrumented(x)
    except Exception:
        # Failures of the function under test are expected for random
        # inputs and are ignored on purpose. Only Exception is caught so
        # that KeyboardInterrupt / SystemExit still stop a long run.
        pass

    # Sum up branch distances: true sides first, then false sides.
    fitness = 0.0
    fitness = _add_uncovered_distances(fitness, distances_true, archive_true_branches, x)
    fitness = _add_uncovered_distances(fitness, distances_false, archive_false_branches, x)

    return fitness,
|
||
|
|
||
|
|
||
|
def random_string():
    """Return a random string of 0 to MAX_STRING_LENGTH characters.

    Each character is drawn from the code points 32..126.
    """
    length = random.randint(0, MAX_STRING_LENGTH)
    return "".join(chr(random.randrange(32, 127)) for _ in range(length))
|
||
|
|
||
|
|
||
|
def crossover(individual1, individual2):
    """Single-point crossover on the strings stored in two individuals.

    When both strings are longer than one character, a cut position is drawn
    from 1..len(first string) and the tails after that position are swapped
    in place. Otherwise both individuals are left unchanged.

    Returns:
        The tuple ``(individual1, individual2)``.
    """
    first, second = individual1[0], individual2[0]
    if len(first) <= 1 or len(second) <= 1:
        return individual1, individual2

    cut = random.randint(1, len(first))
    individual1[0] = first[:cut] + second[cut:]
    individual2[0] = second[:cut] + first[cut:]
    return individual1, individual2
|
||
|
|
||
|
|
||
|
def mutate(individual):
    """Randomly replace characters of the string stored in ``individual``.

    Each position is replaced with probability ``1 / len(string)`` by a
    character drawn from the code points 32..126. The result is written back
    into ``individual[0]``.

    Returns:
        A one-element tuple containing ``individual``.
    """
    genes = individual[0][:]
    length = len(genes)
    if length:
        rate = 1.0 / length
        for idx in range(length):
            if random.random() >= rate:
                continue
            replacement = chr(random.randrange(32, 127))
            genes = genes[:idx] + replacement + genes[idx + 1:]
    individual[0] = genes
    return (individual,)
|
||
|
|
||
|
|
||
|
def generate():
    """Run the genetic algorithm REPS times and print the coverage reached.

    Each run starts with empty archives and a fresh population of NPOP
    individuals, evolves it with ``algorithms.eaSimple`` and prints the number
    of archived branch sides together with both archives.
    """
    global archive_true_branches, archive_false_branches

    # Fitness is minimized (weight -1.0); an individual is a one-element list.
    creator.create("Fitness", base.Fitness, weights=(-1.0,))
    creator.create("Individual", list, fitness=creator.Fitness)

    toolbox = base.Toolbox()
    toolbox.register("attr_str", random_string)
    toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attr_str, n=1)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)
    plain_operators = (
        ("evaluate", get_fitness_cgi),
        ("mate", crossover),
        ("mutate", mutate),
    )
    for alias, operator_fn in plain_operators:
        toolbox.register(alias, operator_fn)
    toolbox.register("select", tools.selTournament, tournsize=TOURNSIZE)

    coverage = []
    for _ in range(REPS):
        archive_true_branches, archive_false_branches = {}, {}
        population = toolbox.population(n=NPOP)
        algorithms.eaSimple(population, toolbox, CXPROB, MUPROB, NGEN, verbose=False)
        covered = len(archive_true_branches) + len(archive_false_branches)
        print(covered, archive_true_branches, archive_false_branches)
        coverage.append(covered)
|
||
|
|
||
|
|
||
|
def instrument(source_path: str, target_path: str):
    """Instrument one Python source file and load the result into this module.

    Steps:
    1. parse ``source_path`` and print a dump of its AST;
    2. rewrite the tree with ``BranchTransformer``;
    3. write the unparsed result to ``target_path`` (the parent directory is
       created when it does not exist yet);
    4. compile the rewritten tree and execute it in this module's namespace.

    Parameters:
        source_path: path of the Python file to instrument.
        target_path: path the instrumented source is written to.
    """
    # Python source files are UTF-8 by default, independent of the locale.
    with open(source_path, "r", encoding="utf-8") as f:
        source = f.read()

    node = ast.parse(source)
    print(ast.dump(node, indent=2))
    BranchTransformer().visit(node)
    node = ast.fix_missing_locations(node)  # Make sure the line numbers are ok before printing

    target_dir = os.path.dirname(target_path)
    if target_dir:
        # Without this the write below fails on a fresh checkout where the
        # output directory has not been created yet.
        os.makedirs(target_dir, exist_ok=True)
    with open(target_path, "w", encoding="utf-8") as f:
        print(astunparse.unparse(node), file=f)

    current_module = sys.modules[__name__]
    code = compile(node, filename="<ast>", mode="exec")
    # NOTE(security): this executes the instrumented file's top-level code
    # inside this module; only run it on trusted benchmark sources.
    exec(code, current_module.__dict__)  # try: cgi_decode_instrumented("a%20%32"), print distances_true
|
||
|
|
||
|
|
||
|
def find_py_files(search_dir: str):
    """Lazily yield the path of every ``*.py`` file below ``search_dir``.

    The directory tree is walked recursively with ``os.walk``; each yielded
    path is the containing directory joined with the file name.
    """
    for folder, _subdirs, filenames in os.walk(search_dir):
        yield from (
            os.path.join(folder, filename)
            for filename in filenames
            if filename.endswith(".py")
        )
|
||
|
|
||
|
|
||
|
def main():
    """Instrument every Python file found under IN_DIR.

    Each instrumented file is written to OUT_DIR under the same base name.
    The call to generate() is currently disabled.
    """
    for source_path in find_py_files(IN_DIR):
        target_path = os.path.join(OUT_DIR, os.path.basename(source_path))
        instrument(source_path, target_path)
    # generate()


if __name__ == '__main__':
    main()
|