34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
from train.dataset import TOKENIZER
|
|
from train.evaluate import compute_accuracy
|
|
|
|
|
|
def test_compute_accuracy():
    """Predictions identical to every label should score 2 correct out of 2.

    ``compute_accuracy(outputs, batch)`` is expected to return a 3-tuple whose
    first two items are the number of correct predictions and the number of
    predictions counted, both as plain ``int``. The third item is not checked
    here.
    """
    batch = {
        'labels': [TOKENIZER.encode("label 1"), TOKENIZER.encode("label 2")],
        'input_ids': [[1, 2], [3, 4]],
    }
    # Each prediction is token-for-token identical to its label.
    outputs = [TOKENIZER.encode("label 1"), TOKENIZER.encode("label 2")]

    result = compute_accuracy(outputs, batch)

    # Validate the return shape *before* unpacking so a wrong return type
    # surfaces as a clear assertion failure rather than an unpacking error.
    assert isinstance(result, tuple), result
    assert len(result) == 3, result
    correct_predictions, total_predictions, _ = result

    assert isinstance(correct_predictions, int), result
    assert isinstance(total_predictions, int), result
    assert correct_predictions == 2, result
    assert total_predictions == 2, result
|
|
|
|
|
|
def test_compute_accuracy_none():
    """An example whose label is ``[-100]`` should be left out of the counts.

    The first example's label is ``[-100]`` (presumably the common "ignore
    this position" marker — confirm against ``compute_accuracy``). The test
    expects it to be excluded from both counts. Only the second example is
    counted, and its prediction matches, giving 1 correct out of 1.
    """
    batch = {
        'labels': [[-100], TOKENIZER.encode("label 2")],
        'input_ids': [[5, 6], [7, 8]],
    }
    # The first prediction would not match its label anyway. What matters is
    # that the ignored example does not contribute to total_predictions.
    outputs = [TOKENIZER.encode("label 1"), TOKENIZER.encode("label 2")]

    result = compute_accuracy(outputs, batch)

    # Validate the return shape *before* unpacking so a wrong return type
    # surfaces as a clear assertion failure rather than an unpacking error.
    assert isinstance(result, tuple), result
    assert len(result) == 3, result
    correct_predictions, total_predictions, _ = result

    assert isinstance(correct_predictions, int), result
    assert isinstance(total_predictions, int), result
    assert correct_predictions == 1, result
    assert total_predictions == 1, result
|