import pandas as pd
import pytest

from src.modelimpl.classifier import bert_classifier
from src.modelimpl.dataset import tokenizer, SplitData, Labelling
from src.modelimpl.evaluate import predict, evaluate, predict_top_k, PredictionResult
from src.modelimpl.torch_dataset import Dataset


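# Minimal stand-in for SplitData, exposing only the labels and texts
# attributes needed to build a Dataset in these tests.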
class MockSplitData:
    def __init__(self, labels, texts):
        self.labels = labels
        self.texts = texts


def test_predict():
    # Create a sample model and dataset
    model = bert_classifier(n_classes=2)
    labels = [0, 1, 1, 0]

    texts = [
        "cats chase playful fuzzy mice",
        "big red ball bounces high",
        "happy sun warms cool breeze",
        "jumping kids laugh on playground",
    ]
    texts = [tokenizer(text, padding='max_length', max_length=512, truncation=True,
                       return_tensors='pt') for text in texts]

    split_data = MockSplitData(labels, texts)
    dataset = Dataset(split_data)

    # Test predict function
    predictions = predict(model, dataset, top_n=2, force_cpu=True)

    # Check the length of predictions
    assert len(predictions) == len(labels)

    # Check the format of PredictionResult instances
    for result in predictions:
        assert isinstance(result, PredictionResult)
        assert len(result.top_values) == 2
        assert len(result.top_indices) == 2
        assert isinstance(result.truth_idx, int)


# Test case for evaluate function
def test_evaluate(capsys):
    # Create a sample model and dataset
    model = bert_classifier(n_classes=2)
    labels = [0, 1, 1, 0]

    texts = [
        "cats chase playful fuzzy mice",
        "big red ball bounces high",
        "happy sun warms cool breeze",
        "jumping kids laugh on playground",
    ]
    texts = [tokenizer(text, padding='max_length', max_length=512, truncation=True,
                       return_tensors='pt') for text in texts]

    split_data = MockSplitData(labels, texts)
    dataset = Dataset(split_data)

    # Test evaluate function
    evaluate(model, dataset, force_cpu=True)

    # Capture the printed output and check the format
    captured = capsys.readouterr()
    assert "recommendations:" in captured.out


# Test case for predict_top_k function
def test_predict_top_k():
    # Create a sample model and dataset
    model = bert_classifier(n_classes=2)

    df = pd.DataFrame({
        "assignee": ["author_0", "author_1", "author_1", "author_0"],
        "title_body": [
            "cats chase playful fuzzy mice",
            "big red ball bounces high",
            "happy sun warms cool breeze",
            "jumping kids laugh on playground",
        ],
    }, index=[1, 2, 3, 4])

    labels = Labelling({
        "author_0": 0,
        "author_1": 1
    })

    split_data = SplitData.from_df(df, labels, 2)
    issue_id = 1

    # Test predict_top_k function
    result = predict_top_k(model, split_data, issue_id, top_n=2, force_cpu=True)

    # Check the format of PredictionResult instance
    assert isinstance(result, PredictionResult)
    assert len(result.top_values) == 2
    assert len(result.top_indices) == 2
    assert isinstance(result.truth_idx, int)

    # An issue_id that is not in the dataset should raise a ValueError
    with pytest.raises(ValueError):
        predict_top_k(model, split_data, issue_id=99, top_n=2, force_cpu=True)