This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
soft-analytics-01/tests/test_modelimpl_auc.py

40 lines
1.5 KiB
Python
Raw Permalink Normal View History

from unittest.mock import MagicMock, patch
import pytest
from src.modelimpl.auc import build_curve, compute_auc_roc
from test_modelimpl_torch_train import mocked_model, mocked_split_data, mocked_labelling
@pytest.fixture
def mock_classifier():
    """Provide a fresh, unconfigured ``MagicMock`` standing in for a classifier.

    NOTE(review): no test in this module requests this fixture; confirm it is
    still needed before relying on it.
    """
    classifier_stub = MagicMock()
    return classifier_stub
def test_build_curve_invalid_path():
    """build_curve raises ValueError for a path that is not a PyTorch model file."""
    # `match` is applied as a regex search; this message has no special characters.
    expected_message = "path should point to a pytorch model file"
    with pytest.raises(ValueError, match=expected_message):
        build_curve("invalid_path", force_cpu=True)
@patch('src.modelimpl.auc.load_model', return_value=(True, MagicMock(), 3))
@patch('src.modelimpl.auc.Labelling.load', return_value=MagicMock())
@patch('src.modelimpl.auc.SplitData.from_df', return_value=MagicMock())
@patch('src.modelimpl.auc.compute_auc_roc')
def test_build_curve_valid_path(mock_compute_auc_roc, mock_from_df, mock_labelling, mock_load_model):
    """build_curve calls load_model exactly once and compute_auc_roc exactly once.

    Stacked ``patch`` decorators inject their mocks bottom-up: the decorator
    nearest the ``def`` (``compute_auc_roc``) supplies the first argument and
    the top one (``load_model``) supplies the last.
    """
    model_path = "valid_path.pt"
    build_curve(model_path, force_cpu=True)

    # Exact positional call expected on load_model; the third argument
    # presumably mirrors force_cpu=True — TODO confirm against load_model.
    mock_load_model.assert_called_once_with(model_path, None, True, False)
    mock_compute_auc_roc.assert_called_once()
def test_compute_auc_roc(mocked_model, mocked_split_data, mocked_labelling, tmp_path):
    """compute_auc_roc writes its three output files next to the given prefix.

    The prefix is ``<tmp_path>/test_file``. The test expects
    ``test_file.ovr_curves.png``, ``test_file.ovr_avg.png`` and
    ``test_file.auc.txt`` to exist afterwards.

    ``tmp_path`` is a per-test directory created and cleaned up by pytest, so
    the generated files need no manual deletion. The previous trailing
    ``unlink()`` calls were redundant, and they were skipped entirely whenever
    an assertion failed.
    """
    prefix = tmp_path / "test_file"
    # Positional args kept as in the original; 3 presumably is the number of
    # classes (matching load_model's mocked return elsewhere) — TODO confirm.
    compute_auc_roc(mocked_model, mocked_split_data[0], 3, mocked_labelling, True,
                    str(prefix))

    for suffix in (".ovr_curves.png", ".ovr_avg.png", ".auc.txt"):
        output = tmp_path / f"test_file{suffix}"
        assert output.exists(), f"missing expected output file {output.name}"