141 lines
4.6 KiB
Python
141 lines
4.6 KiB
Python
from dataclasses import dataclass
from typing import Any
from typing import TypeVar, Callable
from typing import Generic

from nltk import edit_distance
|
|
|
|
T = TypeVar('T')
U = TypeVar('U')


@dataclass
class CmpOp(Generic[T]):
    """Describe one comparison operator together with its distance functions.

    ``@dataclass`` generates ``__init__`` (taking the fields below, in
    declaration order, positionally or by keyword), ``__repr__`` and
    ``__eq__``, so no hand-written constructor is needed.
    """

    # Source-level spelling of the operator, e.g. '<' or '=='.
    operator: str
    # Identifier of the operator, e.g. 'Lt' or 'Eq'; used as the lookup key
    # in the ``*_by_name`` tables of this module.
    name: str
    # Predicate that evaluates the comparison on (lhs, rhs).
    test: Callable[[T, T], bool]
    # Distance of (lhs, rhs) from making the comparison hold; the operator
    # tables in this module return 0 when it already holds.
    true_dist: Callable[[T, T], int]
    # Distance of (lhs, rhs) from making the comparison fail; the operator
    # tables in this module return 0 when it already fails.
    false_dist: Callable[[T, T], int]
|
|
|
|
|
|
# @dataclass
|
|
# class InstrState:
|
|
# min_true_dist: Optional[int]
|
|
# min_false_dist: Optional[int]
|
|
#
|
|
# def __init__(self):
|
|
# self.min_true_dist = None
|
|
# self.min_false_dist = None
|
|
#
|
|
# def update(self, op: CmpOp[U], lhs: U, rhs: U):
|
|
# true_dist = op.true_dist(lhs, rhs)
|
|
# self.min_true_dist = true_dist if self.min_true_dist is None else min(true_dist, self.min_true_dist)
|
|
#
|
|
# false_dist = op.false_dist(lhs, rhs)
|
|
# self.min_false_dist = false_dist if self.min_false_dist is None else min(false_dist, self.min_false_dist)
|
|
#
|
|
#
|
|
# instrumentation_states: defaultdict[int, InstrState] = defaultdict(InstrState)
|
|
|
|
|
|
# Operands for these must both be integers or strings of length 1
|
|
# Every distance lambda below uses integer arithmetic on its operands;
# compute_distances converts one-character strings with int_str_convert
# before calling them.
#   true_dist  -> 0 when the comparison already holds, otherwise how far the
#                 operands are from making it hold.
#   false_dist -> 0 when the comparison already fails, otherwise how far the
#                 operands are from making it fail.
int_str_ops: list[CmpOp[int | str]] = [
    CmpOp(operator='<',
          name='Lt',
          test=lambda lhs, rhs: lhs < rhs,
          true_dist=lambda lhs, rhs: lhs - rhs + 1 if lhs >= rhs else 0,
          false_dist=lambda lhs, rhs: rhs - lhs if lhs < rhs else 0),
    CmpOp(operator='>',
          name='Gt',
          test=lambda lhs, rhs: lhs > rhs,
          true_dist=lambda lhs, rhs: rhs - lhs + 1 if lhs <= rhs else 0,
          false_dist=lambda lhs, rhs: lhs - rhs if lhs > rhs else 0),
    CmpOp(operator='<=',
          name='LtE',
          test=lambda lhs, rhs: lhs <= rhs,
          true_dist=lambda lhs, rhs: lhs - rhs if lhs > rhs else 0,
          false_dist=lambda lhs, rhs: rhs - lhs + 1 if lhs <= rhs else 0),
    CmpOp(operator='>=',
          name='GtE',
          test=lambda lhs, rhs: lhs >= rhs,
          true_dist=lambda lhs, rhs: rhs - lhs if lhs < rhs else 0,
          false_dist=lambda lhs, rhs: lhs - rhs + 1 if lhs >= rhs else 0),
    CmpOp(operator='==',
          name='Eq',
          test=lambda lhs, rhs: lhs == rhs,
          true_dist=lambda lhs, rhs: abs(lhs - rhs),
          false_dist=lambda lhs, rhs: 1 if lhs == rhs else 0),
    CmpOp(operator='!=',
          name='NotEq',
          # Fixed: the predicate for '!=' previously evaluated '=='.
          test=lambda lhs, rhs: lhs != rhs,
          true_dist=lambda lhs, rhs: 1 if lhs == rhs else 0,
          false_dist=lambda lhs, rhs: abs(lhs - rhs)),
]
|
|
|
|
# Name -> operator index over int_str_ops (e.g. 'Lt', 'NotEq'); if two
# entries shared a name, the later one would win.
int_str_by_name: dict[str, CmpOp[int | str]] = {
    op.name: op
    for op in int_str_ops
}
|
|
|
|
|
|
def int_str_check(a: Any, b: Any) -> bool:
    """Return True when both operands are usable with the ``int_str`` operators.

    That is the case when both are exactly ``int``, or both are exactly
    ``str`` with length 1. Exact type comparison is used, so ``bool`` and
    other subclasses of ``int``/``str`` are rejected.
    """
    if type(a) is int and type(b) is int:
        return True
    if type(a) is not str or type(b) is not str:
        return False
    return len(a) == 1 and len(b) == 1
|
|
|
|
|
|
def int_str_convert(x: int | str) -> int:
|
|
if type(x) == int:
|
|
return x
|
|
if len(x) == 1:
|
|
return ord(x)
|
|
|
|
raise ValueError("x must be int or len(str) == 1")
|
|
|
|
|
|
# Operands for these must both be strings
|
|
# Only equality comparisons are defined for general strings; the distance
# towards equality is the edit distance between the operands.
str_ops: list[CmpOp[str]] = [
    CmpOp(operator='==',
          name='Eq',
          test=lambda lhs, rhs: lhs == rhs,
          true_dist=lambda lhs, rhs: edit_distance(lhs, rhs),
          false_dist=lambda lhs, rhs: 1 if lhs == rhs else 0),
    CmpOp(operator='!=',
          name='NotEq',
          # Fixed: the predicate for '!=' previously evaluated '=='.
          test=lambda lhs, rhs: lhs != rhs,
          true_dist=lambda lhs, rhs: 1 if lhs == rhs else 0,
          false_dist=lambda lhs, rhs: edit_distance(lhs, rhs)),
]
|
|
|
|
# Name -> operator index over str_ops. The values are the CmpOp[str] entries
# of str_ops, so the annotation uses CmpOp[str] rather than CmpOp[int | str].
str_by_name: dict[str, CmpOp[str]] = {c.name: c for c in str_ops}
|
|
|
|
|
|
def str_check(a: Any, b: Any) -> bool:
    """Return True when both operands are exactly ``str`` (of any length).

    Exact type comparison is used, so instances of ``str`` subclasses are
    rejected.
    """
    return type(a) is str and type(b) is str
|
|
|
|
|
|
def compute_distances(name: str, lhs: Any, rhs: Any) -> tuple[int, int]:
    """Return ``(true_dist, false_dist)`` for comparison ``name`` on the operands.

    Operand pairs accepted by ``int_str_check`` (two ints, or two
    one-character strings) are converted with ``int_str_convert`` and
    evaluated with the operator found in ``int_str_by_name``. Any other pair
    of strings is evaluated with the operator found in ``str_by_name``.

    Args:
        name: operator name used as key in the lookup tables, e.g. ``'Eq'``.
        lhs: left operand of the comparison.
        rhs: right operand of the comparison.

    Raises:
        ValueError: if ``name`` is not present in the table selected by the
            operand types, or if the operands match neither operator family.
    """
    # The int/str family is tried first, so two one-character strings never
    # reach the general string operators below.
    if int_str_check(lhs, rhs):
        lhs_int = int_str_convert(lhs)
        rhs_int = int_str_convert(rhs)

        if name not in int_str_by_name:
            raise ValueError(f"'{name}' is not a valid CmpOp name for 'int_str' operators")

        op = int_str_by_name[name]
        return op.true_dist(lhs_int, rhs_int), op.false_dist(lhs_int, rhs_int)

    if str_check(lhs, rhs):
        if name not in str_by_name:
            raise ValueError(f"'{name}' is not a valid CmpOp name for 'str' operators")

        op = str_by_name[name]
        return op.true_dist(lhs, rhs), op.false_dist(lhs, rhs)

    raise ValueError(f"'{lhs}' and '{rhs}' are not suitable for both 'int_str' and 'str' operators")
|