This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
soft-analytics-02/test/test_evaluate.py
Claudio Maggioni a4ceee8716 Final version of the project
History has been rewritten to delete large files in repo
2024-01-03 15:28:43 +01:00

34 lines
1.1 KiB
Python

from train.dataset import TOKENIZER
from train.evaluate import compute_accuracy
def test_compute_accuracy():
    """All predictions match their labels, so both should count as correct."""
    labels = [TOKENIZER.encode("label 1"), TOKENIZER.encode("label 2")]
    predictions = [TOKENIZER.encode("label 1"), TOKENIZER.encode("label 2")]
    batch = {'labels': labels, 'input_ids': [[1, 2], [3, 4]]}

    result = compute_accuracy(predictions, batch)
    print(result)

    # Result is (correct, total, <extra>); the third element is not asserted here.
    correct, total, _ = result
    assert isinstance(result, tuple)
    assert isinstance(correct, int)
    assert isinstance(total, int)
    assert correct == 2
    assert total == 2
def test_compute_accuracy_none():
    """A label of [-100] (ignore index) is excluded from both counts."""
    labels = [[-100], TOKENIZER.encode("label 2")]
    predictions = [TOKENIZER.encode("label 1"), TOKENIZER.encode("label 2")]
    batch = {'labels': labels, 'input_ids': [[5, 6], [7, 8]]}

    result = compute_accuracy(predictions, batch)
    print(result)

    # Only the second example is scored, hence totals of 1.
    correct, total, _ = result
    assert isinstance(result, tuple)
    assert isinstance(correct, int)
    assert isinstance(total, int)
    assert correct == 1
    assert total == 1