This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
kse-02/instrumentor.py

142 lines
4.6 KiB
Python
Raw Normal View History

2023-11-13 13:45:51 +00:00
from dataclasses import dataclass
2023-11-13 15:33:20 +00:00
from typing import TypeVar, Callable
2023-11-13 13:45:51 +00:00
from typing import Generic
from nltk import edit_distance
T = TypeVar('T')
U = TypeVar('U')


@dataclass
class CmpOp(Generic[T]):
    """A binary comparison operator together with its branch-distance functions.

    The constructor is generated by ``@dataclass`` and takes, in order,
    ``operator``, ``name``, ``test``, ``true_dist`` and ``false_dist``
    (positionally or by keyword).
    """

    # Source-level spelling of the operator, e.g. '<' or '=='.
    operator: str
    # Short identifier used as the lookup key in this module's *_by_name tables.
    name: str
    # Evaluates the comparison itself: test(lhs, rhs) -> bool.
    test: Callable[[T, T], bool]
    # Distance from making the comparison evaluate to True (0 when it already does).
    true_dist: Callable[[T, T], int]
    # Distance from making the comparison evaluate to False (0 when it already does).
    false_dist: Callable[[T, T], int]
2023-11-13 15:33:20 +00:00
# @dataclass
# class InstrState:
# min_true_dist: Optional[int]
# min_false_dist: Optional[int]
#
# def __init__(self):
# self.min_true_dist = None
# self.min_false_dist = None
#
# def update(self, op: CmpOp[U], lhs: U, rhs: U):
# true_dist = op.true_dist(lhs, rhs)
# self.min_true_dist = true_dist if self.min_true_dist is None else min(true_dist, self.min_true_dist)
#
# false_dist = op.false_dist(lhs, rhs)
# self.min_false_dist = false_dist if self.min_false_dist is None else min(false_dist, self.min_false_dist)
#
#
# instrumentation_states: defaultdict[int, InstrState] = defaultdict(InstrState)
2023-11-13 13:45:51 +00:00
# Operands for these must both be integers or strings of length 1.
# compute_distances converts single characters to their code points with
# int_str_convert before calling these lambdas, so they only see ints.
#
# Each true_dist / false_dist is 0 when the comparison already has the
# corresponding outcome, and otherwise a positive distance to flipping it.
# For strict operators the "+ 1" accounts for equality not satisfying them
# (e.g. Lt on lhs == rhs still needs to move by 1).
int_str_ops: list[CmpOp[int | str]] = [
    CmpOp(operator='<',
          name='Lt',
          test=lambda lhs, rhs: lhs < rhs,
          true_dist=lambda lhs, rhs: lhs - rhs + 1 if lhs >= rhs else 0,
          false_dist=lambda lhs, rhs: rhs - lhs if lhs < rhs else 0),
    CmpOp(operator='>',
          name='Gt',
          test=lambda lhs, rhs: lhs > rhs,
          true_dist=lambda lhs, rhs: rhs - lhs + 1 if lhs <= rhs else 0,
          false_dist=lambda lhs, rhs: lhs - rhs if lhs > rhs else 0),
    CmpOp(operator='<=',
          name='LtE',
          test=lambda lhs, rhs: lhs <= rhs,
          true_dist=lambda lhs, rhs: lhs - rhs if lhs > rhs else 0,
          false_dist=lambda lhs, rhs: rhs - lhs + 1 if lhs <= rhs else 0),
    CmpOp(operator='>=',
          name='GtE',
          test=lambda lhs, rhs: lhs >= rhs,
          true_dist=lambda lhs, rhs: rhs - lhs if lhs < rhs else 0,
          false_dist=lambda lhs, rhs: lhs - rhs + 1 if lhs >= rhs else 0),
    CmpOp(operator='==',
          name='Eq',
          test=lambda lhs, rhs: lhs == rhs,
          true_dist=lambda lhs, rhs: abs(lhs - rhs),
          false_dist=lambda lhs, rhs: 1 if lhs == rhs else 0),
    CmpOp(operator='!=',
          name='NotEq',
          # Previously `lhs == rhs`, which inverted the operator.
          test=lambda lhs, rhs: lhs != rhs,
          true_dist=lambda lhs, rhs: 1 if lhs == rhs else 0,
          false_dist=lambda lhs, rhs: abs(lhs - rhs)),
]
# Lookup table keyed by CmpOp.name (e.g. 'Lt', 'NotEq').
int_str_by_name: dict[str, CmpOp[int | str]] = {c.name: c for c in int_str_ops}
def int_str_check(a: any, b: any) -> bool:
    """Return True if (a, b) can be compared with the int/char operator table.

    Accepted pairs are two values whose exact type is int (an int subclass
    such as bool is rejected), or two values whose exact type is str and
    which are each exactly one character long. Mixed pairs are rejected.
    """
    if type(a) is int and type(b) is int:
        return True
    both_str = type(a) is str and type(b) is str
    return both_str and len(a) == len(b) == 1
2023-11-13 13:45:51 +00:00
def int_str_convert(x: int | str) -> int:
if type(x) == int:
return x
if len(x) == 1:
return ord(x)
raise ValueError("x must be int or len(str) == 1")
# Operands for these must both be strings.
# Equality distances use nltk's edit_distance between the two strings;
# as in the int table, a distance is 0 when the comparison already has the
# corresponding outcome.
str_ops: list[CmpOp[str]] = [
    CmpOp(operator='==',
          name='Eq',
          test=lambda lhs, rhs: lhs == rhs,
          true_dist=lambda lhs, rhs: edit_distance(lhs, rhs),
          false_dist=lambda lhs, rhs: 1 if lhs == rhs else 0),
    CmpOp(operator='!=',
          name='NotEq',
          # Previously `lhs == rhs`, which inverted the operator.
          test=lambda lhs, rhs: lhs != rhs,
          true_dist=lambda lhs, rhs: 1 if lhs == rhs else 0,
          false_dist=lambda lhs, rhs: edit_distance(lhs, rhs)),
]
# Lookup table keyed by CmpOp.name ('Eq', 'NotEq'); values are str operators.
str_by_name: dict[str, CmpOp[str]] = {c.name: c for c in str_ops}
def str_check(a: any, b: any) -> bool:
    """Return True when both operands have exact type str (of any length)."""
    return all(type(operand) is str for operand in (a, b))
2023-11-13 15:33:20 +00:00
def compute_distances(name: str, lhs: any, rhs: any) -> tuple[int, int]:
    """Return (true_dist, false_dist) for comparing lhs and rhs with operator `name`.

    Two ints, or two single-character strs, are handled by the int/char table
    (characters are converted with int_str_convert, i.e. ord()). Any other
    pair of strs is handled by the str table. Because the int/char check runs
    first, two single-character strings always take the numeric path.

    Raises:
        ValueError: if `name` is not a key of the selected table, or if the
            operands are accepted by neither int_str_check nor str_check.
    """
    if int_str_check(lhs, rhs):
        numeric_op = int_str_by_name.get(name)
        if numeric_op is None:
            raise ValueError(f"'{name}' is not a valid CmpOp name for 'int_str' operators")
        left, right = int_str_convert(lhs), int_str_convert(rhs)
        return numeric_op.true_dist(left, right), numeric_op.false_dist(left, right)

    if not str_check(lhs, rhs):
        raise ValueError(f"'{lhs}' and '{rhs}' are not suitable for both 'int_str' and 'str' operators")

    string_op = str_by_name.get(name)
    if string_op is None:
        raise ValueError(f"'{name}' is not a valid CmpOp name for 'str' operators")
    return string_op.true_dist(lhs, rhs), string_op.false_dist(lhs, rhs)