This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
soft-analytics-01/tests/test_modelimpl_torch_train.py
Claudio Maggioni 07232eddcc Final version of the bug-triaging project
Commit history has been discarded to remove large files from the repo.
2024-01-03 15:22:56 +01:00

83 lines
2.7 KiB
Python

import pandas as pd
import pytest
import torch
from torch.utils.data import DataLoader
from src.modelimpl.classifier import bert_classifier
from src.modelimpl.dataset import Labelling, SplitData
from src.modelimpl.torch_dataset import Dataset
from src.modelimpl.torch_train import train, print_message, compute_loss_and_acc
@pytest.fixture
def mocked_labelling():
    """Provide a ``Labelling`` built from a fixed three-author mapping."""
    author_to_label = {"author_0": 0, "author_1": 1, "author_2": 2}
    return Labelling(author_to_label)
@pytest.fixture
def mocked_split_data(mocked_labelling) -> tuple[SplitData, SplitData]:
    """Build a pair of ``SplitData`` objects from a tiny in-memory frame.

    The frame holds five rows indexed 1..5 with an ``assignee`` and a
    ``title_body`` column. The first returned element is built from rows
    1-3 and the second from rows 4-5; ``mocked_data`` consumes them as the
    training and validation sets respectively.
    """
    assignees = ["author_0", "author_1", "author_2", "author_1", "author_0"]
    sentences = [
        "cats chase playful fuzzy mice",
        "big red ball bounces high",
        "happy sun warms cool breeze",
        "jumping kids laugh on playground",
        "test sentence number 5",
    ]
    frame = pd.DataFrame(
        {"assignee": assignees, "title_body": sentences},
        index=[1, 2, 3, 4, 5],
    )

    # NOTE(review): the trailing 3 presumably matches the number of labels
    # in the labelling -- confirm against SplitData.from_df.
    first_split = SplitData.from_df(frame.loc[[1, 2, 3]], mocked_labelling, 3)
    second_split = SplitData.from_df(frame.loc[[4, 5]], mocked_labelling, 3)
    return first_split, second_split
@pytest.fixture
def mocked_data(mocked_split_data: tuple[SplitData, SplitData]):
    """Wrap both splits in ``DataLoader`` objects using a batch size of 2.

    Returns a ``(train_loader, val_loader)`` tuple, in the same order as
    the splits provided by ``mocked_split_data``.
    """
    train_split, val_split = mocked_split_data
    train_loader = DataLoader(Dataset(train_split), batch_size=2)
    val_loader = DataLoader(Dataset(val_split), batch_size=2)
    return train_loader, val_loader
@pytest.fixture
def mocked_model():
    """Provide the object returned by ``bert_classifier`` for three classes."""
    classifier = bert_classifier(n_classes=3)
    return classifier
def test_train_without_errors(capfd, mocked_model, mocked_data):
    """Run ``train`` for two epochs and check that both epochs are reported.

    ``train`` receives the datasets wrapped by the two loaders (not the
    loaders themselves) and is called with ``force_cpu=True``.
    """
    train_loader, val_loader = mocked_data
    train(
        mocked_model,
        train_loader.dataset,
        val_loader.dataset,
        learning_rate=0.001,
        epochs=2,
        force_cpu=True,
    )

    stdout = capfd.readouterr().out
    for expected in ("Epochs: 1", "Epochs: 2"):
        assert expected in stdout
def test_print_message(capsys):
    """Check that ``print_message`` writes the epoch, loss and accuracy fields.

    A minimal stand-in dataset exposing only a ``texts`` list is passed for
    both the training and the validation dataset, each with one element.
    """

    class MockDataset:
        """Stand-in dataset that only carries a ``texts`` list of given length."""

        # ``object`` instead of ``any``: the latter is the builtin function,
        # not a type, so ``list[any]`` is meaningless to type checkers.
        texts: list[object]

        def __init__(self, length: int):
            self.texts = [None] * length

    # noinspection PyTypeChecker
    print_message(epoch_num=1, train_loss=2.0, train_acc=0.7, train_ds=MockDataset(1), val_loss=1.0, val_acc=0.8,
                  val_ds=MockDataset(1))
    captured = capsys.readouterr()

    # NOTE(review): epoch_num=1 is expected to be printed as "Epochs: 2";
    # presumably print_message reports epoch_num + 1 -- confirm.
    assert "Epochs: 2" in captured.out
    assert "Train Loss: 2.000" in captured.out
    assert "Train Accuracy: 0.700" in captured.out
    assert "Val Loss: 1.000" in captured.out
    assert "Val Accuracy: 0.800" in captured.out
def test_compute_loss_and_acc(mocked_model, mocked_data):
    """Check the types of the three values returned by ``compute_loss_and_acc``.

    The first batch of the training loader is fed to the model on the CPU
    device; only the result types are verified, not their values.
    """
    train_data, _ = mocked_data  # the validation loader is not used here
    device = torch.device("cpu")
    model = mocked_model

    # NOTE(review): ``return_value`` only has an effect on unittest.mock
    # objects. The fixture builds the model via ``bert_classifier``, so this
    # looks like a leftover with no effect -- confirm and remove.
    model.return_value = torch.tensor([[0.2, 0.8], [0.5, 0.5]])

    # Use the builtin ``iter`` rather than calling ``__iter__`` directly.
    batch_input, batch_label = next(iter(train_data))
    loss, acc, batch_loss = compute_loss_and_acc(batch_label, batch_input, device, model)

    assert isinstance(loss, float)
    assert isinstance(acc, int)
    assert isinstance(batch_loss, torch.Tensor)