39 lines
1.5 KiB
Python
39 lines
1.5 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from src.modelimpl.auc import build_curve, compute_auc_roc
|
|
from test_modelimpl_torch_train import mocked_model, mocked_split_data, mocked_labelling
|
|
|
|
|
|
@pytest.fixture
def mock_classifier():
    """Provide a fresh, unconfigured ``MagicMock`` to stand in for a classifier.

    NOTE(review): no test visible in this module requests this fixture;
    confirm it is used elsewhere before keeping it.
    """
    classifier = MagicMock()
    return classifier
|
|
|
|
|
|
def test_build_curve_invalid_path():
    """``build_curve`` raises ``ValueError`` for a path that is not a model file.

    ``"invalid_path"`` lacks the ``.pt`` suffix used by the valid-path test
    below, so it is expected to be rejected with the message checked here.
    """
    expected_message = "path should point to a pytorch model file"
    with pytest.raises(ValueError, match=expected_message):
        build_curve("invalid_path", force_cpu=True)
|
|
|
|
|
|
@patch('src.modelimpl.auc.load_model', return_value=(True, MagicMock(), 3))
@patch('src.modelimpl.auc.Labelling.load', return_value=MagicMock())
@patch('src.modelimpl.auc.SplitData.from_df', return_value=MagicMock())
@patch('src.modelimpl.auc.compute_auc_roc')
def test_build_curve_valid_path(mock_compute_auc_roc, mock_from_df,
                                mock_labelling, mock_load_model):
    """``build_curve`` loads the model once and delegates to ``compute_auc_roc``.

    Stacked ``@patch`` decorators inject their mocks bottom-up, so the
    innermost patch (``compute_auc_roc``) arrives as the first argument and
    the outermost (``load_model``) as the last.
    """
    model_path = "valid_path.pt"

    build_curve(model_path, force_cpu=True)

    # NOTE(review): the trailing positional values (None, True, False) mirror
    # build_curve's internal call to load_model; presumably one of the True
    # values is force_cpu — confirm against src.modelimpl.auc.
    mock_load_model.assert_called_once_with(model_path, None, True, False)
    mock_compute_auc_roc.assert_called_once()
|
|
|
|
|
|
def test_compute_auc_roc(mocked_model, mocked_split_data, mocked_labelling, tmp_path):
    """``compute_auc_roc`` writes the OvR curve plot, the averaged plot and the AUC report.

    Outputs are written under pytest's ``tmp_path``. That directory is
    managed by pytest, which also prunes old runs, so no manual ``unlink()``
    is needed. The previous manual cleanup was skipped whenever an assertion
    failed, so it never reliably cleaned up in any case.
    """
    prefix = tmp_path / "test_file"

    # NOTE(review): ``3`` and ``True`` are passed positionally; ``3`` presumably
    # is the class count and ``True`` a flag — confirm against compute_auc_roc's
    # signature and consider passing them by keyword.
    compute_auc_roc(mocked_model, mocked_split_data[0], 3, mocked_labelling, True,
                    str(prefix))

    # Each artefact is ``<prefix><suffix>``; report which one is missing on failure.
    for suffix in (".ovr_curves.png", ".ovr_avg.png", ".auc.txt"):
        output = tmp_path / f"test_file{suffix}"
        assert output.exists(), f"compute_auc_roc did not write {output.name}"
|