67 lines
2.1 KiB
Python
67 lines
2.1 KiB
Python
|
import os
|
||
|
import pandas as pd
|
||
|
import pytest
|
||
|
from src.modelimpl.dataset import prepare_input, load_df, compute_labels, Labelling, SplitData, df_validation_split, \
|
||
|
Datasets
|
||
|
|
||
|
|
||
|
@pytest.fixture
def sample_dataframe():
    """Provide a two-row issue DataFrame.

    Columns, in order: ``id``, ``title``, ``body``, ``title_body`` and
    ``assignee``. ``title_body`` is each title joined to its body with a
    newline.
    """
    titles = ['Title1', 'Title2']
    bodies = ['Body1', 'Body2']
    # Derive the combined column from its parts so the relationship is explicit.
    combined = [title + '\n' + body for title, body in zip(titles, bodies)]
    return pd.DataFrame({
        'id': [1, 2],
        'title': titles,
        'body': bodies,
        'title_body': combined,
        'assignee': ['A', 'B'],
    })
|
||
|
|
||
|
|
||
|
def test_prepare_input(sample_dataframe):
    """Check that prepare_input keeps only title_body and assignee, with the fixture's values."""
    prepared = prepare_input(sample_dataframe)

    assert list(prepared.columns) == ['title_body', 'assignee']
    assert prepared['title_body'].tolist() == ['Title1\nBody1', 'Title2\nBody2']
    assert prepared['assignee'].tolist() == ['A', 'B']
|
||
|
|
||
|
|
||
|
def test_load_df(sample_dataframe, tmpdir):
    """Write the fixture to a CSV file, load it with load_df, and check columns and row count."""
    csv_path = os.path.join(tmpdir, 'sample_issues.csv')
    # index=False keeps the pandas row index out of the file, so only the
    # fixture's own columns are written.
    sample_dataframe.to_csv(csv_path, index=False)

    loaded = load_df(csv_path)

    assert list(loaded.columns) == ['title_body', 'assignee']
    assert len(loaded) == len(sample_dataframe)
|
||
|
|
||
|
|
||
|
def test_compute_labels():
    """Check the label mapping and bounds compute_labels returns for two overlapping frames."""
    first = pd.DataFrame({'assignee': ['A', 'B', 'C']})
    second = pd.DataFrame({'assignee': ['B', 'C', 'D']})

    labels_dict, num_bounds = compute_labels([first, second])

    assert labels_dict == {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    # NOTE(review): looks like cumulative distinct-label counts (3 after the
    # first frame, 4 after the second) — confirm against compute_labels.
    assert num_bounds == [0, 3, 4]
|
||
|
|
||
|
|
||
|
def test_labelling_methods(tmpdir):
    """Check that saving a Labelling to a .csv path and loading it back gives the same labels."""
    original = Labelling({'A': 0, 'B': 1, 'C': 2})

    target = os.path.join(tmpdir, 'test_labels.csv')
    original.save(target)
    restored = Labelling.load(target)

    assert original.labels == restored.labels
|
||
|
|
||
|
|
||
|
def test_split_data_methods(sample_dataframe):
    """Check that SplitData.from_df has the same length as the DataFrame it was built from."""
    labelling = Labelling({'A': 0, 'B': 1})
    # NOTE(review): the meaning of the third argument (1) is defined by
    # SplitData.from_df, which is not visible here — confirm.
    split_data = SplitData.from_df(sample_dataframe, labelling, 1)

    assert len(split_data) == len(sample_dataframe)
|
||
|
|
||
|
|
||
|
def test_df_validation_split(sample_dataframe):
    """Check that df_validation_split returns two non-empty parts whose sizes sum to the input's row count."""
    train_part, val_part = df_validation_split(sample_dataframe)

    for part in (train_part, val_part):
        assert len(part) > 0
    assert len(train_part) + len(val_part) == len(sample_dataframe)
|