"""Tests for ``build_curve`` and ``compute_auc_roc`` in ``src.modelimpl.auc``."""
from unittest.mock import MagicMock, patch

import pytest

from src.modelimpl.auc import build_curve, compute_auc_roc
# These names are pytest fixtures defined in the torch training tests. Importing
# them into this module makes pytest discover them here as fixtures; linters see
# them as unused imports, hence the noqa.
from test_modelimpl_torch_train import (  # noqa: F401
    mocked_labelling,
    mocked_model,
    mocked_split_data,
)

# File-name suffixes that test_compute_auc_roc expects compute_auc_roc to write
# alongside the output prefix it is given.
_AUC_ARTIFACT_SUFFIXES = (".ovr_curves.png", ".ovr_avg.png", ".auc.txt")


@pytest.fixture
def mock_classifier():
    """Provide a bare ``MagicMock`` standing in for a classifier.

    NOTE(review): no test in this module requests this fixture; it is kept in
    case other test modules import it.
    """
    return MagicMock()


def test_build_curve_invalid_path():
    """``build_curve`` raises ``ValueError`` for a path that is not a PyTorch model file."""
    with pytest.raises(ValueError, match="path should point to a pytorch model file"):
        build_curve("invalid_path", force_cpu=True)


# ``patch`` decorators are applied bottom-up, so the injected mocks arrive in
# reverse order: compute_auc_roc first, load_model last.
@patch('src.modelimpl.auc.load_model', return_value=(True, MagicMock(), 3))
@patch('src.modelimpl.auc.Labelling.load', return_value=MagicMock())
@patch('src.modelimpl.auc.SplitData.from_df', return_value=MagicMock())
@patch('src.modelimpl.auc.compute_auc_roc')
def test_build_curve_valid_path(mock_compute_auc_roc, mock_from_df, mock_labelling, mock_load_model):
    """With its collaborators mocked, ``build_curve`` loads the model and computes AUC-ROC once."""
    build_curve("valid_path.pt", force_cpu=True)

    mock_load_model.assert_called_once_with("valid_path.pt", None, True, False)
    mock_compute_auc_roc.assert_called_once()


def test_compute_auc_roc(mocked_model, mocked_split_data, mocked_labelling, tmp_path):
    """``compute_auc_roc`` writes its curve plots and AUC summary next to the output prefix.

    ``tmp_path`` is a per-test directory whose lifetime pytest manages, so the
    generated files do not need to be removed by hand.
    """
    output_prefix = tmp_path / "test_file"
    # NOTE(review): the positional ``3`` and ``True`` presumably are the class
    # count and a force-CPU flag — confirm against compute_auc_roc's signature.
    compute_auc_roc(mocked_model, mocked_split_data[0], 3, mocked_labelling, True, str(output_prefix))

    for suffix in _AUC_ARTIFACT_SUFFIXES:
        artifact = tmp_path / f"test_file{suffix}"
        assert artifact.exists(), f"missing artifact: {artifact.name}"