"""Tests for ``load_model`` from ``src.modelimpl.load``.

Every test calls ``load_model`` with ``label_range=(1, 5)`` and unpacks its
return value as ``(already_trained, model, n_classes)``.
"""

import os

import pytest
import torch

from src.modelimpl.classifier import bert_classifier, Classifier
from src.modelimpl.load import load_model


@pytest.fixture
def model_instance():
    """Return a freshly constructed classifier with four output classes."""
    return bert_classifier(n_classes=4)


@pytest.fixture
def model_path(tmpdir):
    """Return a ``.pt`` path inside the per-test temp dir.

    The file itself is not created here; tests that need a saved model
    write it themselves.
    """
    return os.path.join(tmpdir, "test_model.pt")


def test_load_model_with_valid_path(model_path):
    """A ``.pt`` path with no saved file yields an untrained 4-class model."""
    result, model, classes = load_model(
        model_path, label_range=(1, 5), force_cpu=False, force_retrain=False
    )

    assert result is False  # The model should not be already trained
    assert isinstance(model, Classifier)
    assert classes == 4  # The range (1, 5) implies 4 classes


def test_load_model_with_invalid_path():
    """A path that is not a pytorch model file is rejected with ValueError."""
    with pytest.raises(
        ValueError, match="path should point to a pytorch model file"
    ):
        load_model(
            "invalid_path.txt",
            label_range=(1, 5),
            force_cpu=False,
            force_retrain=False,
        )


def test_load_model_with_force_retrain(model_path):
    """With ``force_retrain=True`` the model is reported as not trained."""
    # NOTE(review): no model file is saved before this call, so this case
    # does not distinguish force_retrain=True from the default path --
    # consider saving a state dict first to exercise the flag properly.
    result, _, _ = load_model(
        model_path, label_range=(1, 5), force_cpu=False, force_retrain=True
    )

    # The model should not be already trained, but force_retrain is True
    assert result is False


def test_load_model_with_force_cpu(model_path):
    """With ``force_cpu=True`` the returned model must live on the CPU."""
    result, model, _ = load_model(
        model_path, label_range=(1, 5), force_cpu=True, force_retrain=False
    )

    assert result is False  # The model should not be already trained
    assert isinstance(model, Classifier)
    # Check where the model's weights actually are rather than whether the
    # host has CUDA: asserting `not torch.cuda.is_available()` fails on any
    # GPU machine regardless of what load_model did.
    # Assumes Classifier is a torch.nn.Module (state_dict() is used on it
    # elsewhere in this file) -- TODO confirm.
    assert all(param.device.type == "cpu" for param in model.parameters())


def test_load_model_with_already_trained_model(model_path, model_instance):
    """A previously saved state dict is detected as an already trained model."""
    torch.save(model_instance.state_dict(), model_path)

    result, model, classes = load_model(
        model_path, label_range=(1, 5), force_cpu=False, force_retrain=False
    )

    assert result is True  # The model should be already trained
    assert isinstance(model, Classifier)
    assert classes == 4  # The range (1, 5) implies 4 classes