import pandas as pd
import pytest
import swifter
import torch
from torch.utils.data import DataLoader

from train.dataset import (TOKENIZER, MAX_TOKEN_LENGTH, PythonCodeDataset, MaskedIfDataset, decode_tokenized,
                           PRETRAIN_MLM_PROB, BATCH_SIZE, build_pretrain_dataloader, build_fine_tune_dataloader)


@pytest.fixture
def mock_pretrain_data():
    data = {'source': ['if a > 2: pass', 'if b <= 4: pass'],
            'other_column': [1, 2]}

    return pd.DataFrame(data)


@pytest.fixture
def mock_fine_tune_dataloader():
    data = {'source': ['if a > 2: pass', 'if b <= 4: pass'],
            'other_column': [1, 2]}

    return pd.DataFrame(data)


@pytest.fixture
def mock_fine_tune_data():
    data = {'masked_code': ['if a > 2: pass', 'if b <= 4: pass'],
            'ground_truth': ['if a > 2: pass', 'if b <= 4: pass']}

    return pd.DataFrame(data)


@pytest.fixture
def mock_tokenized_output():
    return [1234, 5678]


def test_decode_tokenized(mock_tokenized_output):
    decoded_output = decode_tokenized(mock_tokenized_output)
    expected_output = " msg comments"
    assert decoded_output == expected_output

    mock_tokenized_output_with_padding = [-100]
    decoded_output_with_padding = decode_tokenized(mock_tokenized_output_with_padding)
    expected_output_with_padding = None
    assert decoded_output_with_padding == expected_output_with_padding
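

# Hedged addition, not part of the original suite: a loose round-trip check. It assumes
# TOKENIZER is a Hugging Face-style tokenizer (so .encode(text, add_special_tokens=False)
# exists) and that decode_tokenized simply decodes non-padding token ids back to text.
def test_decode_tokenized_round_trip():
    token_ids = TOKENIZER.encode('if a > 2: pass', add_special_tokens=False)
    decoded = decode_tokenized(token_ids)
    # Only a loose containment check, since tokenizers may normalize whitespace.
    assert 'pass' in decoded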


def test_build_pretrain_dataloader(mock_pretrain_data):
    dataloader = build_pretrain_dataloader(mock_pretrain_data)

    assert isinstance(dataloader, DataLoader)
    assert dataloader.batch_size == BATCH_SIZE
    assert isinstance(dataloader.dataset, PythonCodeDataset)
    assert dataloader.dataset.tokenizer == TOKENIZER
    assert dataloader.dataset.data.equals(mock_pretrain_data)
    assert dataloader.collate_fn.tokenizer == TOKENIZER
    assert dataloader.collate_fn.mlm_probability == PRETRAIN_MLM_PROB
    assert dataloader.collate_fn.mlm is True


def test_build_fine_tune_dataloader(mock_fine_tune_dataloader):
    train_dataloader = build_fine_tune_dataloader(mock_fine_tune_dataloader, 'train')

    assert isinstance(train_dataloader, DataLoader)
    assert train_dataloader.batch_size == BATCH_SIZE
    assert isinstance(train_dataloader.dataset, PythonCodeDataset)
    assert train_dataloader.dataset.tokenizer == TOKENIZER


def test_python_code_dataset(mock_pretrain_data):
    dataset = PythonCodeDataset(TOKENIZER, mock_pretrain_data, MAX_TOKEN_LENGTH)
    sample = dataset[0]

    assert len(dataset) == len(mock_pretrain_data)
    assert 'input_ids' in sample
    assert 'attention_mask' in sample
    assert sample['input_ids'].shape == torch.Size([MAX_TOKEN_LENGTH])
    assert sample['attention_mask'].shape == torch.Size([MAX_TOKEN_LENGTH])


def test_masked_if_dataset(mock_fine_tune_data):
    dataset = MaskedIfDataset(TOKENIZER, mock_fine_tune_data, MAX_TOKEN_LENGTH)
    sample = dataset[0]

    assert len(dataset) == len(mock_fine_tune_data)
    assert 'input_ids' in sample
    assert 'attention_mask' in sample
    assert 'labels' in sample
    assert sample['input_ids'].shape == torch.Size([MAX_TOKEN_LENGTH])
    assert sample['attention_mask'].shape == torch.Size([MAX_TOKEN_LENGTH])
    assert sample['labels'].shape == torch.Size([MAX_TOKEN_LENGTH])
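

# Hedged addition, not part of the original suite: a smoke test that a batch can actually
# be drawn from the pretrain dataloader. It assumes the collate_fn is an MLM collator in
# the style of Hugging Face's DataCollatorForLanguageModeling (returning 'input_ids' and
# 'labels' tensors whose last dimension is the fixed sequence length) and that the loader
# keeps the final, smaller batch (drop_last=False, the DataLoader default).
def test_pretrain_dataloader_yields_batches(mock_pretrain_data):
    dataloader = build_pretrain_dataloader(mock_pretrain_data)
    batch = next(iter(dataloader))

    assert 'input_ids' in batch
    assert 'labels' in batch
    assert batch['input_ids'].shape[-1] == MAX_TOKEN_LENGTH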