This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
soft-analytics-01/tests/test_modelimpl_dataset.py

67 lines
2.1 KiB
Python
Raw Permalink Normal View History

import os
import pandas as pd
import pytest
from src.modelimpl.dataset import prepare_input, load_df, compute_labels, Labelling, SplitData, df_validation_split, \
Datasets
@pytest.fixture
def sample_dataframe():
    """Provide a two-row DataFrame with id, title, body, title_body and assignee columns.

    Each title_body value is the matching title and body joined by a newline.
    """
    columns = {
        'id': [1, 2],
        'title': ['Title1', 'Title2'],
        'body': ['Body1', 'Body2'],
        'title_body': ['Title1\nBody1', 'Title2\nBody2'],
        'assignee': ['A', 'B'],
    }
    return pd.DataFrame(columns)
def test_prepare_input(sample_dataframe):
    """Assert prepare_input yields only title_body and assignee, with the fixture's values."""
    prepared = prepare_input(sample_dataframe)

    # Column order matters here: the comparison is against an ordered list.
    assert list(prepared.columns) == ['title_body', 'assignee']
    assert prepared['title_body'].tolist() == ['Title1\nBody1', 'Title2\nBody2']
    assert prepared['assignee'].tolist() == ['A', 'B']
def test_load_df(sample_dataframe, tmpdir):
    """Write the sample frame to a CSV in tmpdir and assert load_df's columns and row count."""
    # Persist the fixture without its index so the CSV holds only the data columns.
    csv_path = os.path.join(tmpdir, 'sample_issues.csv')
    sample_dataframe.to_csv(csv_path, index=False)

    loaded = load_df(csv_path)

    expected_columns = ['title_body', 'assignee']
    assert list(loaded.columns) == expected_columns
    assert len(loaded) == len(sample_dataframe)
def test_compute_labels():
    """Assert compute_labels' label mapping and bounds for two overlapping assignee frames."""
    # The second frame repeats 'B' and 'C' from the first and adds a new 'D'.
    first_frame = pd.DataFrame({'assignee': ['A', 'B', 'C']})
    second_frame = pd.DataFrame({'assignee': ['B', 'C', 'D']})

    labels_dict, num_bounds = compute_labels([first_frame, second_frame])

    assert labels_dict == {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    assert num_bounds == [0, 3, 4]
def test_labelling_methods(tmpdir):
    """Save a Labelling to a file in tmpdir, load it back, and compare the labels."""
    labelling = Labelling({'A': 0, 'B': 1, 'C': 2})
    labels_path = os.path.join(tmpdir, 'test_labels.csv')

    labelling.save(labels_path)
    restored = Labelling.load(labels_path)

    assert labelling.labels == restored.labels
def test_split_data_methods(sample_dataframe):
    """Assert SplitData.from_df yields an object as long as the input frame."""
    labelling = Labelling({'A': 0, 'B': 1})
    # NOTE(review): the meaning of the third positional argument (1) is not
    # visible from this file; confirm against SplitData.from_df.
    split = SplitData.from_df(sample_dataframe, labelling, 1)

    assert len(split) == len(sample_dataframe)
def test_df_validation_split(sample_dataframe):
    """Assert df_validation_split returns two non-empty parts whose sizes sum to the input's."""
    df_train, df_val = df_validation_split(sample_dataframe)
    train_rows, val_rows = len(df_train), len(df_val)

    assert train_rows > 0
    assert val_rows > 0
    assert train_rows + val_rows == len(sample_dataframe)