50 lines
1.9 KiB
Python
50 lines
1.9 KiB
Python
import os
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from src.modelimpl.classifier import bert_classifier, Classifier
|
|
from src.modelimpl.load import load_model
|
|
|
|
|
|
@pytest.fixture
def model_instance():
    """Provide a freshly built ``bert_classifier`` with four output classes.

    Four matches the class count the ``load_model`` tests in this module
    expect for ``label_range=(1, 5)``, so its ``state_dict`` can be saved and
    handed to ``load_model`` as an existing checkpoint.
    """
    num_classes = 4
    classifier = bert_classifier(n_classes=num_classes)
    return classifier
|
|
|
|
|
|
@pytest.fixture
def model_path(tmp_path):
    """Return a ``str`` path to ``test_model.pt`` inside a per-test temp directory.

    Uses pytest's ``tmp_path`` (a ``pathlib.Path``) instead of the legacy
    ``py.path``-based ``tmpdir`` fixture. The result is still a plain string,
    exactly as before, in case ``load_model`` validates the path with string
    operations (assumption — confirm against ``load_model``).

    The file itself is not created here; tests that need an existing
    checkpoint write one with ``torch.save``.
    """
    # os.path.join accepts PathLike objects and always returns a str.
    return os.path.join(tmp_path, "test_model.pt")
|
|
|
|
|
|
def test_load_model_with_valid_path(model_path):
    """A well-formed ``.pt`` path with no checkpoint on disk yields an untrained model."""
    was_loaded, classifier, n_classes = load_model(
        model_path,
        label_range=(1, 5),
        force_cpu=False,
        force_retrain=False,
    )

    # Nothing has been saved at model_path, so nothing can have been loaded.
    assert was_loaded is False
    assert isinstance(classifier, Classifier)
    # label_range=(1, 5) is expected to correspond to 4 classes.
    assert n_classes == 4
|
|
|
|
|
|
def test_load_model_with_invalid_path():
    """A path that does not look like a PyTorch model file (a ``.txt``) raises ``ValueError``."""
    bad_path = "invalid_path.txt"
    expected_message = "path should point to a pytorch model file"

    # pytest treats `match` as a regex searched in str(exc); this message has no metacharacters.
    with pytest.raises(ValueError, match=expected_message):
        load_model(
            bad_path,
            label_range=(1, 5),
            force_cpu=False,
            force_retrain=False,
        )
|
|
|
|
|
|
def test_load_model_with_force_retrain(model_path, model_instance):
    """``force_retrain=True`` should ignore an existing checkpoint and return an untrained model.

    A checkpoint is written first on purpose. Without a file on disk,
    ``load_model`` reports ``False`` regardless of the flag, so the test would
    not actually exercise ``force_retrain``.
    """
    torch.save(model_instance.state_dict(), model_path)

    result, model, classes = load_model(
        model_path,
        label_range=(1, 5),
        force_cpu=False,
        force_retrain=True,
    )

    # The setup is identical to the "already trained" case; only the flag differs.
    assert result is False
    assert isinstance(model, Classifier)
    assert classes == 4  # label_range=(1, 5) implies 4 classes
|
|
|
|
|
|
def test_load_model_with_force_cpu(model_path):
    """``force_cpu=True`` should return a model whose parameters live on the CPU.

    This inspects the returned model's parameter devices rather than asserting
    ``not torch.cuda.is_available()``. That call describes the host machine,
    not ``load_model``: it fails on any CUDA-capable runner and passes
    vacuously on CPU-only ones.
    """
    result, model, classes = load_model(
        model_path,
        label_range=(1, 5),
        force_cpu=True,
        force_retrain=False,
    )

    assert result is False  # no checkpoint exists at model_path yet
    assert isinstance(model, Classifier)
    assert classes == 4  # label_range=(1, 5) implies 4 classes
    # Assumes Classifier is a torch.nn.Module (this file saves a
    # bert_classifier's state_dict()), so .parameters() exists — TODO confirm.
    assert all(param.device.type == "cpu" for param in model.parameters())
|
|
|
|
|
|
def test_load_model_with_already_trained_model(model_path, model_instance):
    """A ``state_dict`` saved at ``model_path`` is detected and reported as already trained."""
    checkpoint = model_instance.state_dict()
    torch.save(checkpoint, model_path)

    outcome = load_model(
        model_path,
        label_range=(1, 5),
        force_cpu=False,
        force_retrain=False,
    )
    was_loaded, classifier, n_classes = outcome

    # The checkpoint written above should be picked up.
    assert was_loaded is True
    assert isinstance(classifier, Classifier)
    assert n_classes == 4  # label_range=(1, 5) implies 4 classes
|