84 lines
2.7 KiB
Python
84 lines
2.7 KiB
Python
|
import pandas as pd
|
||
|
import pytest
|
||
|
import torch
|
||
|
from torch.utils.data import DataLoader
|
||
|
|
||
|
from src.modelimpl.classifier import bert_classifier
|
||
|
from src.modelimpl.dataset import Labelling, SplitData
|
||
|
from src.modelimpl.torch_dataset import Dataset
|
||
|
from src.modelimpl.torch_train import train, print_message, compute_loss_and_acc
|
||
|
|
||
|
|
||
|
@pytest.fixture
def mocked_labelling():
    """Provide a ``Labelling`` built from three author names mapped to 0, 1 and 2."""
    author_to_label = {
        "author_0": 0,
        "author_1": 1,
        "author_2": 2,
    }
    return Labelling(author_to_label)
|
||
|
|
||
|
|
||
|
@pytest.fixture
def mocked_split_data(mocked_labelling) -> tuple[SplitData, SplitData]:
    """Build a five-row DataFrame and turn it into two ``SplitData`` objects.

    The frame has an ``assignee`` and a ``title_body`` column and is indexed
    1..5. Rows 1-3 feed the first ``SplitData``, rows 4-5 the second; both are
    created through ``SplitData.from_df`` with the shared labelling fixture.
    """
    assignees = ["author_0", "author_1", "author_2", "author_1", "author_0"]
    texts = [
        "cats chase playful fuzzy mice",
        "big red ball bounces high",
        "happy sun warms cool breeze",
        "jumping kids laugh on playground",
        "test sentence number 5",
    ]
    frame = pd.DataFrame(
        {"assignee": assignees, "title_body": texts},
        index=[1, 2, 3, 4, 5],
    )

    # Build the splits one after the other, in the same order as the rows.
    first_split = SplitData.from_df(frame.loc[[1, 2, 3]], mocked_labelling, 3)
    second_split = SplitData.from_df(frame.loc[[4, 5]], mocked_labelling, 3)
    return first_split, second_split
|
||
|
|
||
|
|
||
|
@pytest.fixture
def mocked_data(mocked_split_data: tuple[SplitData, SplitData]):
    """Wrap each of the two splits in a ``Dataset`` and a ``DataLoader`` (batch size 2).

    Returns a ``(train_loader, val_loader)`` tuple in the same order as the
    incoming splits.
    """
    train_loader, val_loader = (
        DataLoader(Dataset(split), batch_size=2) for split in mocked_split_data
    )
    return train_loader, val_loader
|
||
|
|
||
|
|
||
|
@pytest.fixture
def mocked_model():
    """Provide the object returned by ``bert_classifier`` for three classes."""
    classifier = bert_classifier(n_classes=3)
    return classifier
|
||
|
|
||
|
|
||
|
def test_train_without_errors(capfd, mocked_model, mocked_data):
    """Run ``train`` for two epochs on CPU and check both epochs show up on stdout."""
    train_loader, val_loader = mocked_data

    # ``train`` receives the underlying datasets, not the loaders themselves.
    train(
        mocked_model,
        train_loader.dataset,
        val_loader.dataset,
        learning_rate=0.001,
        epochs=2,
        force_cpu=True,
    )

    stdout = capfd.readouterr().out
    for expected in ("Epochs: 1", "Epochs: 2"):
        assert expected in stdout
|
||
|
|
||
|
|
||
|
def test_print_message(capsys):
    """Check that ``print_message`` writes the epoch and all four metrics to stdout.

    ``print_message`` is called with ``epoch_num=1`` and two single-element
    mock datasets; the captured output must contain ``"Epochs: 2"`` and the
    loss/accuracy values formatted with three decimals.
    NOTE(review): the ``Epochs: 2`` expectation suggests the epoch is printed
    1-based -- confirm against ``print_message``.
    """

    class MockDataset:
        """Minimal stand-in that only exposes a ``texts`` list of a given length."""

        # ``object`` instead of the builtin function ``any``, which is not a type.
        texts: list[object]

        def __init__(self, length: int):
            self.texts = [None] * length

    # noinspection PyTypeChecker
    print_message(epoch_num=1, train_loss=2.0, train_acc=0.7, train_ds=MockDataset(1), val_loss=1.0, val_acc=0.8,
                  val_ds=MockDataset(1))

    captured = capsys.readouterr()
    assert "Epochs: 2" in captured.out
    assert "Train Loss: 2.000" in captured.out
    assert "Train Accuracy: 0.700" in captured.out
    assert "Val Loss: 1.000" in captured.out
    assert "Val Accuracy: 0.800" in captured.out
|
||
|
|
||
|
|
||
|
def test_compute_loss_and_acc(mocked_model, mocked_data):
    """Run ``compute_loss_and_acc`` on one training batch and check the result types.

    The first batch of the training loader is passed through the model on the
    CPU; the call must return a ``float`` loss, an ``int`` accuracy count and a
    ``torch.Tensor`` batch loss.
    """
    # Only the training loader is used here; the validation loader is ignored.
    train_loader, _ = mocked_data

    device = torch.device("cpu")
    model = mocked_model
    # NOTE(review): ``mocked_model`` is whatever ``bert_classifier`` returns, not
    # a ``Mock``; this ``return_value`` assignment looks like a leftover from a
    # mock-based version of the test -- confirm whether it has any effect.
    model.return_value = torch.tensor([[0.2, 0.8], [0.5, 0.5]])

    batch_input, batch_label = next(iter(train_loader))
    loss, acc, batch_loss = compute_loss_and_acc(batch_label, batch_input, device, model)

    assert isinstance(loss, float)
    assert isinstance(acc, int)
    assert isinstance(batch_loss, torch.Tensor)
|