from collections import defaultdict from dataclasses import dataclass from typing import TypeVar, Callable, Optional from typing import Generic from nltk import edit_distance T = TypeVar('T') U = TypeVar('U') @dataclass class CmpOp(Generic[T]): operator: str name: str test: Callable[[T, T], bool] true_dist: Callable[[T, T], int] false_dist: Callable[[T, T], int] def __init__(self, operator: str, name: str, test: Callable[[T, T], bool], true_dist: Callable[[T, T], int], false_dist: Callable[[T, T], int]): self.operator = operator self.name = name self.test = test self.true_dist = true_dist self.false_dist = false_dist @dataclass class InstrState: min_true_dist: Optional[int] min_false_dist: Optional[int] def __init__(self): self.min_true_dist = None self.min_false_dist = None def update(self, op: CmpOp[U], lhs: U, rhs: U): true_dist = op.true_dist(lhs, rhs) self.min_true_dist = true_dist if self.min_true_dist is None else min(true_dist, self.min_true_dist) false_dist = op.false_dist(lhs, rhs) self.min_false_dist = false_dist if self.min_false_dist is None else min(false_dist, self.min_false_dist) instrumentation_states: defaultdict[int, InstrState] = defaultdict(InstrState) # Operands for these must both be integers or strings of length 1 int_str_ops: list[CmpOp[int | str]] = [ CmpOp(operator='<', name='Lt', test=lambda lhs, rhs: lhs < rhs, true_dist=lambda lhs, rhs: lhs - rhs + 1 if lhs >= rhs else 0, false_dist=lambda lhs, rhs: rhs - lhs if lhs < rhs else 0), CmpOp(operator='>', name='Gt', test=lambda lhs, rhs: lhs > rhs, true_dist=lambda lhs, rhs: rhs - lhs + 1 if lhs <= rhs else 0, false_dist=lambda lhs, rhs: lhs - rhs if lhs > rhs else 0), CmpOp(operator='<=', name='LtE', test=lambda lhs, rhs: lhs <= rhs, true_dist=lambda lhs, rhs: lhs - rhs if lhs > rhs else 0, false_dist=lambda lhs, rhs: rhs - lhs + 1 if lhs <= rhs else 0), CmpOp(operator='>=', name='GtE', test=lambda lhs, rhs: lhs >= rhs, true_dist=lambda lhs, rhs: rhs - lhs if lhs < rhs else 0, false_dist=lambda lhs, rhs: lhs - rhs + 1 if lhs >= rhs else 0), CmpOp(operator='==', name='Eq', test=lambda lhs, rhs: lhs == rhs, true_dist=lambda lhs, rhs: abs(lhs - rhs), false_dist=lambda lhs, rhs: 1 if lhs == rhs else 0), CmpOp(operator='!=', name='NotEq', test=lambda lhs, rhs: lhs == rhs, true_dist=lambda lhs, rhs: 1 if lhs == rhs else 0, false_dist=lambda lhs, rhs: abs(lhs - rhs)), ] int_str_by_name: dict[str, CmpOp[int | str]] = {c.name: c for c in int_str_ops} def int_str_check(a: any, b: any) -> bool: if type(a) == int and type(b) == int: return True if type(a) != str or type(b) != str: return False return len(a) == 1 or len(b) == 1 def int_str_convert(x: int | str) -> int: if type(x) == int: return x if len(x) == 1: return ord(x) raise ValueError("x must be int or len(str) == 1") # Operands for these must both be strings str_ops: list[CmpOp[str]] = [ CmpOp(operator='==', name='Eq', test=lambda lhs, rhs: lhs == rhs, true_dist=lambda lhs, rhs: edit_distance(lhs, rhs), false_dist=lambda lhs, rhs: 1 if lhs == rhs else 0), CmpOp(operator='!=', name='NotEq', test=lambda lhs, rhs: lhs == rhs, true_dist=lambda lhs, rhs: 1 if lhs == rhs else 0, false_dist=lambda lhs, rhs: edit_distance(lhs, rhs)), ] str_by_name: dict[str, CmpOp[int | str]] = {c.name: c for c in str_ops} def str_check(a: any, b: any) -> bool: return type(a) == str and type(b) == str def evaluate_condition(cmp_id: int, name: str, lhs: any, rhs: any) -> bool: if int_str_check(lhs, rhs): lhs_int = int_str_convert(lhs) rhs_int = int_str_convert(rhs) if name not in int_str_by_name: raise ValueError(f"'{name}' is not a valid CmpOp name for 'int_str' operators") op = int_str_by_name[name] instrumentation_states[cmp_id].update(op, lhs_int, rhs_int) return op.test(lhs_int, rhs_int) if str_check(lhs, rhs): if name not in str_by_name: raise ValueError(f"'{name}' is not a valid CmpOp name for 'str' operators") op = int_str_by_name[name] instrumentation_states[cmp_id].update(op, lhs, rhs) return op.test(lhs, rhs) raise ValueError(f"'{lhs}' and '{rhs}' are not suitable for both 'int_str' and 'str' operators")