2023-12-09 16:56:04 +00:00
|
|
|
import sys
|
2023-11-13 13:45:51 +00:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import Generic
|
2023-12-09 16:56:04 +00:00
|
|
|
from typing import TypeVar, Callable
|
2023-11-13 13:45:51 +00:00
|
|
|
|
|
|
|
from nltk import edit_distance
|
|
|
|
|
2023-12-09 16:56:04 +00:00
|
|
|
# Smallest "distance to true" recorded so far for each condition number
# (update_map keeps the minimum of the stored and the new value).
distances_true: dict[int, int] = {}
# Smallest "distance to false" recorded so far for each condition number.
distances_false: dict[int, int] = {}

# Every "distance to true" recorded for each condition number, in the order
# update_maps appended them.
distances_true_all: dict[int, list[int]] = {}
# Every "distance to false" recorded for each condition number.
distances_false_all: dict[int, list[int]] = {}
|
|
|
|
|
2023-11-13 13:45:51 +00:00
|
|
|
T = TypeVar('T')
U = TypeVar('U')


@dataclass
class CmpOp(Generic[T]):
    """A comparison operator together with its branch-distance functions.

    The constructor is the one generated by ``@dataclass``; it takes the
    fields below, in this order, positionally or by keyword.
    """

    # Source-level spelling of the operator, e.g. '<' or '=='.
    operator: str
    # Lookup key for the operator, e.g. 'Lt' or 'Eq'.
    name: str
    # Evaluates the comparison itself on (lhs, rhs).
    test: Callable[[T, T], bool]
    # Distance from (lhs, rhs) to the comparison being true.
    true_dist: Callable[[T, T], int]
    # Distance from (lhs, rhs) to the comparison being false.
    false_dist: Callable[[T, T], int]
|
|
|
|
|
|
|
|
|
|
|
|
# These operators work on two integers; a pair of one-character strings is
# also accepted once each character has been turned into its code point.
int_str_ops: list[CmpOp[int | str]] = [
    CmpOp(
        operator='<',
        name='Lt',
        test=lambda a, b: a < b,
        true_dist=lambda a, b: max(a - b + 1, 0),
        false_dist=lambda a, b: max(b - a, 0),
    ),
    CmpOp(
        operator='>',
        name='Gt',
        test=lambda a, b: a > b,
        true_dist=lambda a, b: max(b - a + 1, 0),
        false_dist=lambda a, b: max(a - b, 0),
    ),
    CmpOp(
        operator='<=',
        name='LtE',
        test=lambda a, b: a <= b,
        true_dist=lambda a, b: max(a - b, 0),
        false_dist=lambda a, b: max(b - a + 1, 0),
    ),
    CmpOp(
        operator='>=',
        name='GtE',
        test=lambda a, b: a >= b,
        true_dist=lambda a, b: max(b - a, 0),
        false_dist=lambda a, b: max(a - b + 1, 0),
    ),
    CmpOp(
        operator='==',
        name='Eq',
        test=lambda a, b: a == b,
        true_dist=lambda a, b: abs(a - b),
        false_dist=lambda a, b: int(a == b),
    ),
    CmpOp(
        operator='!=',
        name='NotEq',
        test=lambda a, b: a != b,
        true_dist=lambda a, b: int(a == b),
        false_dist=lambda a, b: abs(a - b),
    ),
]

# Index of the operators above, keyed by their ``name`` field.
int_str_by_name: dict[str, CmpOp[int | str]] = {op.name: op for op in int_str_ops}
|
|
|
|
|
|
|
|
|
|
|
|
def int_str_check(a: object, b: object) -> bool:
    """Return True if both operands are ints or both are 1-character strings.

    The checks are on the exact type, so ``bool`` values and instances of
    ``int``/``str`` subclasses are rejected.
    """
    if type(a) is int and type(b) is int:
        return True
    if type(a) is not str or type(b) is not str:
        return False
    return len(a) == 1 and len(b) == 1
|
2023-11-13 13:45:51 +00:00
|
|
|
|
|
|
|
|
|
|
|
def int_str_convert(x: int | str) -> int:
|
|
|
|
if type(x) == int:
|
|
|
|
return x
|
|
|
|
if len(x) == 1:
|
|
|
|
return ord(x)
|
|
|
|
|
|
|
|
raise ValueError("x must be int or len(str) == 1")
|
|
|
|
|
|
|
|
|
|
|
|
# Operands for these must both be strings
str_ops: list[CmpOp[str]] = [
    CmpOp(operator='==',
          name='Eq',
          test=lambda lhs, rhs: lhs == rhs,
          # Edit distance is 0 exactly when the strings are equal.
          true_dist=lambda lhs, rhs: edit_distance(lhs, rhs),
          false_dist=lambda lhs, rhs: 1 if lhs == rhs else 0),
    CmpOp(operator='!=',
          name='NotEq',
          test=lambda lhs, rhs: lhs != rhs,
          true_dist=lambda lhs, rhs: 1 if lhs == rhs else 0,
          false_dist=lambda lhs, rhs: edit_distance(lhs, rhs)),
]

# Index of the string operators above, keyed by their ``name`` field.
str_by_name: dict[str, CmpOp[str]] = {c.name: c for c in str_ops}
|
|
|
|
|
|
|
|
|
|
|
|
def str_check(a: any, b: any) -> bool:
|
|
|
|
return type(a) == str and type(b) == str
|
|
|
|
|
|
|
|
|
2023-12-09 16:56:04 +00:00
|
|
|
def compute_distances(name: str, lhs: object, rhs: object) -> tuple[int, int, bool]:
    """Compute both branch distances and the outcome of a comparison.

    Two ints or two 1-character strings are handled by the 'int_str'
    operators (characters are converted to ints first); any other pair of
    strings is handled by the 'str' operators.

    Args:
        name: operator name used as the lookup key, e.g. 'Lt' or 'Eq'.
        lhs: left operand of the comparison.
        rhs: right operand of the comparison.

    Returns:
        A tuple ``(distance_to_true, distance_to_false, test_result)``.

    Raises:
        ValueError: if ``name`` is not defined for the operand kind, or if
            the operands fit neither the 'int_str' nor the 'str' operators.
    """
    if int_str_check(lhs, rhs):
        lhs_int = int_str_convert(lhs)
        rhs_int = int_str_convert(rhs)

        try:
            op = int_str_by_name[name]
        except KeyError:
            raise ValueError(f"'{name}' is not a valid CmpOp name for 'int_str' operators") from None

        return op.true_dist(lhs_int, rhs_int), op.false_dist(lhs_int, rhs_int), op.test(lhs_int, rhs_int)

    if str_check(lhs, rhs):
        try:
            op = str_by_name[name]
        except KeyError:
            raise ValueError(f"'{name}' is not a valid CmpOp name for 'str' operators") from None

        return op.true_dist(lhs, rhs), op.false_dist(lhs, rhs), op.test(lhs, rhs)

    raise ValueError(f"'{lhs}' and '{rhs}' are not suitable for both 'int_str' and 'str' operators")
|
2023-12-09 16:56:04 +00:00
|
|
|
|
|
|
|
|
|
|
|
def update_map(the_map: dict[int, int], condition_num: int, distance: int) -> None:
    """Record ``distance`` for ``condition_num``, keeping the smallest value seen.

    Args:
        the_map: mapping from condition number to minimum distance; mutated
            in place.
        condition_num: key of the condition being updated.
        distance: newly observed distance for that condition.
    """
    # For a new key the fallback is ``distance`` itself, so it is stored as is.
    the_map[condition_num] = min(the_map.get(condition_num, distance), distance)
|
|
|
|
|
|
|
|
|
|
|
|
def update_maps(condition_num: int, d_true: int, d_false: int) -> None:
    """Record one pair of distances for ``condition_num`` in the module-level maps.

    The minimum maps (``distances_true`` / ``distances_false``) keep the
    smallest value seen; the ``*_all`` maps collect every value in call order.
    """
    # The maps are only mutated, never rebound, so no ``global`` is needed.
    update_map(distances_true, condition_num, d_true)
    distances_true_all.setdefault(condition_num, []).append(d_true)

    update_map(distances_false, condition_num, d_false)
    distances_false_all.setdefault(condition_num, []).append(d_false)
|
2023-12-09 16:56:04 +00:00
|
|
|
|
|
|
|
|
|
|
|
def in_op(num, lhs, rhs):
    """Record distances for a membership test and return its outcome.

    Args:
        num: condition number under which the distances are recorded.
        lhs: value looked up; a string is converted with ``ord``.
        rhs: container of characters. Iterating a dict yields its keys, so
            dicts behave as before; other iterables of characters work too.

    Returns:
        True if some element of ``rhs`` has the same code point as ``lhs``.
    """
    if isinstance(lhs, str):
        lhs = ord(lhs)

    # Distance to true is the gap to the closest element; an empty container
    # yields sys.maxsize.
    distance_true = min((abs(lhs - ord(elem)) for elem in rhs), default=sys.maxsize)
    distance_false = 1 if distance_true == 0 else 0
    update_maps(num, distance_true, distance_false)
    return distance_true == 0  # distance == 0 equivalent to actual test by construction
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_condition(num, op, lhs, rhs):
    """Evaluate condition ``num``, record its distances and return the outcome.

    The "In" operator is delegated to ``in_op``; every other operator name
    goes through ``compute_distances`` and the result is stored with
    ``update_maps``.
    """
    if op != "In":
        d_true, d_false, outcome = compute_distances(op, lhs, rhs)
        update_maps(num, d_true, d_false)
        return outcome

    return in_op(num, lhs, rhs)
|