soft-analytics-01/tests/test_modelimpl_evaluate.py

import pandas as pd
import pytest
from src.modelimpl.classifier import bert_classifier
from src.modelimpl.dataset import tokenizer, SplitData, Labelling
from src.modelimpl.evaluate import predict, evaluate, predict_top_k, PredictionResult
from src.modelimpl.torch_dataset import Dataset


class MockSplitData:
    def __init__(self, labels, texts):
        self.labels = labels
        self.texts = texts


def test_predict():
    # Create a sample model and dataset
    model = bert_classifier(n_classes=2)
    labels = [0, 1, 1, 0]
    texts = [
        "cats chase playful fuzzy mice",
        "big red ball bounces high",
        "happy sun warms cool breeze",
        "jumping kids laugh on playground",
    ]
    texts = [tokenizer(text, padding='max_length', max_length=512, truncation=True,
                       return_tensors='pt') for text in texts]
    split_data = MockSplitData(labels, texts)
    dataset = Dataset(split_data)

    # Test predict function
    predictions = predict(model, dataset, top_n=2, force_cpu=True)

    # Check the length of predictions
    assert len(predictions) == len(labels)

    # Check the format of PredictionResult instances
    for result in predictions:
        assert isinstance(result, PredictionResult)
        assert len(result.top_values) == 2
        assert len(result.top_indices) == 2
        assert isinstance(result.truth_idx, int)


# Test case for evaluate function
def test_evaluate(capsys):
    # Create a sample model and dataset
    model = bert_classifier(n_classes=2)
    labels = [0, 1, 1, 0]
    texts = [
        "cats chase playful fuzzy mice",
        "big red ball bounces high",
        "happy sun warms cool breeze",
        "jumping kids laugh on playground",
    ]
    texts = [tokenizer(text, padding='max_length', max_length=512, truncation=True,
                       return_tensors='pt') for text in texts]
    split_data = MockSplitData(labels, texts)
    dataset = Dataset(split_data)

    # Test evaluate function
    evaluate(model, dataset, force_cpu=True)

    # Capture the printed output and check the format
    captured = capsys.readouterr()
    assert "recommendations:" in captured.out


# Test case for predict_top_k function
def test_predict_top_k():
    # Create a sample model and dataset
    model = bert_classifier(n_classes=2)
    df = pd.DataFrame({
        "assignee": ["author_0", "author_1", "author_1", "author_0"],
        "title_body": [
            "cats chase playful fuzzy mice",
            "big red ball bounces high",
            "happy sun warms cool breeze",
            "jumping kids laugh on playground",
        ],
    }, index=[1, 2, 3, 4])
    labels = Labelling({
        "author_0": 0,
        "author_1": 1
    })
    split_data = SplitData.from_df(df, labels, 2)
    issue_id = 1

    # Test predict_top_k function
    result = predict_top_k(model, split_data, issue_id, top_n=2, force_cpu=True)

    # Check the format of the PredictionResult instance
    assert isinstance(result, PredictionResult)
    assert len(result.top_values) == 2
    assert len(result.top_indices) == 2
    assert isinstance(result.truth_idx, int)

    # Check that an unknown issue_id raises a ValueError
    with pytest.raises(ValueError):
        predict_top_k(model, split_data, issue_id=99, top_n=2, force_cpu=True)
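

# The model/dataset setup is duplicated across test_predict and test_evaluate.
# Below is a minimal sketch of how it could be shared via a pytest fixture; the
# fixture name `small_dataset` and the extra test are illustrative additions,
# not part of the original suite.
@pytest.fixture
def small_dataset():
    labels = [0, 1, 1, 0]
    texts = [
        "cats chase playful fuzzy mice",
        "big red ball bounces high",
        "happy sun warms cool breeze",
        "jumping kids laugh on playground",
    ]
    texts = [tokenizer(text, padding='max_length', max_length=512, truncation=True,
                       return_tensors='pt') for text in texts]
    return Dataset(MockSplitData(labels, texts))


def test_predict_with_shared_dataset(small_dataset):
    # Mirrors test_predict above, but reuses the shared fixture.
    model = bert_classifier(n_classes=2)
    predictions = predict(model, small_dataset, top_n=2, force_cpu=True)
    assert len(predictions) == 4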