This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
soft-analytics-01/tests/test_modelimpl_load.py

51 lines
1.9 KiB
Python
Raw Normal View History

import os
import pytest
import torch
from src.modelimpl.classifier import bert_classifier, Classifier
from src.modelimpl.load import load_model
@pytest.fixture
def model_instance():
    """Provide a freshly built BERT classifier with four output classes.

    Four classes matches the ``label_range=(1, 5)`` used with ``load_model``
    throughout this module, so the saved state dict is shape-compatible with
    the model ``load_model`` builds.
    """
    classifier = bert_classifier(n_classes=4)
    return classifier
@pytest.fixture
def model_path(tmp_path):
    """Return a ``.pt`` file path inside a fresh per-test temporary directory.

    No file is created at this path; tests that need a checkpoint write one
    themselves. Uses pytest's ``tmp_path`` (a ``pathlib.Path``) instead of
    the legacy ``tmpdir`` fixture. The result is converted to ``str`` so
    callers still receive the same type that ``os.path.join`` produced.
    """
    return str(tmp_path / "test_model.pt")
def test_load_model_with_valid_path(model_path):
    """A ``.pt`` path with nothing saved behind it yields a not-yet-trained model."""
    already_trained, loaded, n_classes = load_model(
        model_path,
        label_range=(1, 5),
        force_cpu=False,
        force_retrain=False,
    )
    # Nothing has been written to model_path, so there are no weights to load.
    assert already_trained is False
    assert isinstance(loaded, Classifier)
    # label_range=(1, 5) is expected to map to 4 classes.
    assert n_classes == 4
def test_load_model_with_invalid_path():
    """A path that is not a PyTorch model file (here ``.txt``) raises ``ValueError``."""
    bad_path = "invalid_path.txt"
    expected_message = "path should point to a pytorch model file"
    with pytest.raises(ValueError, match=expected_message):
        load_model(
            bad_path,
            label_range=(1, 5),
            force_cpu=False,
            force_retrain=False,
        )
def test_load_model_with_force_retrain(model_path, model_instance):
    """``force_retrain=True`` must not report a saved checkpoint as trained.

    A checkpoint is written first so the flag is actually exercised. With no
    file at ``model_path``, ``load_model`` already reports False whatever
    ``force_retrain`` is, so an empty path could not detect a broken flag.
    """
    # Persist real weights; without force_retrain these would be picked up.
    torch.save(model_instance.state_dict(), model_path)

    result, model, classes = load_model(
        model_path,
        label_range=(1, 5),
        force_cpu=False,
        force_retrain=True,
    )
    # NOTE(review): expected contract is that force_retrain skips the saved
    # checkpoint, so the model is reported as not yet trained — confirm
    # against src.modelimpl.load.load_model.
    assert result is False
    assert isinstance(model, Classifier)
    # label_range=(1, 5) is expected to map to 4 classes.
    assert classes == 4
def test_load_model_with_force_cpu(model_path):
    """``force_cpu=True`` must leave the model on the CPU, even if CUDA exists.

    The assertion checks the device of the returned model's parameters. It
    does not use ``torch.cuda.is_available()``, which reports on the host
    machine rather than on ``load_model`` and fails on any GPU-equipped runner.
    """
    result, model, classes = load_model(
        model_path,
        label_range=(1, 5),
        force_cpu=True,
        force_retrain=False,
    )
    assert result is False  # nothing has been saved at model_path
    assert isinstance(model, Classifier)
    # label_range=(1, 5) is expected to map to 4 classes.
    assert classes == 4
    # Classifier instances expose state_dict() in this module, so the model is
    # assumed to be a torch.nn.Module providing .parameters().
    devices = {param.device.type for param in model.parameters()}
    assert devices == {"cpu"}
def test_load_model_with_already_trained_model(model_path, model_instance):
    """A checkpoint saved at ``model_path`` is reported as already trained."""
    # Write the fixture model's weights where load_model will look for them.
    weights = model_instance.state_dict()
    torch.save(weights, model_path)

    trained, loaded, n_classes = load_model(
        model_path,
        label_range=(1, 5),
        force_cpu=False,
        force_retrain=False,
    )
    assert trained is True
    assert isinstance(loaded, Classifier)
    # label_range=(1, 5) is expected to map to 4 classes.
    assert n_classes == 4