"""Tests for predict, evaluate, and predict_top_k in src.modelimpl.evaluate.

All tests run on CPU (force_cpu=True) against a tiny 2-class BERT classifier
and four short sample documents, so no GPU or large fixture is needed.
"""

import pandas as pd
import pytest

from src.modelimpl.classifier import bert_classifier
from src.modelimpl.dataset import tokenizer, SplitData, Labelling
from src.modelimpl.evaluate import predict, evaluate, predict_top_k, PredictionResult
from src.modelimpl.torch_dataset import Dataset

# Sample corpus shared by every test in this module (previously duplicated
# verbatim inside each test function).
_SAMPLE_TEXTS = [
    "cats chase playful fuzzy mice",
    "big red ball bounces high",
    "happy sun warms cool breeze",
    "jumping kids laugh on playground",
]
_SAMPLE_LABELS = [0, 1, 1, 0]


class MockSplitData:
    """Minimal stand-in for SplitData exposing only the attributes Dataset reads."""

    def __init__(self, labels, texts):
        self.labels = labels
        self.texts = texts


def _make_dataset():
    """Tokenize the sample texts and wrap them in a Dataset (shared test fixture)."""
    tokenized = [
        tokenizer(text, padding='max_length', max_length=512,
                  truncation=True, return_tensors='pt')
        for text in _SAMPLE_TEXTS
    ]
    return Dataset(MockSplitData(_SAMPLE_LABELS, tokenized))


def _assert_prediction_result(result, top_n):
    """Assert the PredictionResult contract: type, top-k lengths, int truth index."""
    assert isinstance(result, PredictionResult)
    assert len(result.top_values) == top_n
    assert len(result.top_indices) == top_n
    assert isinstance(result.truth_idx, int)


def test_predict():
    """predict() returns one well-formed PredictionResult per input sample."""
    model = bert_classifier(n_classes=2)
    dataset = _make_dataset()

    predictions = predict(model, dataset, top_n=2, force_cpu=True)

    assert len(predictions) == len(_SAMPLE_LABELS)
    for result in predictions:
        _assert_prediction_result(result, top_n=2)


def test_evaluate(capsys):
    """evaluate() prints a report containing a 'recommendations:' section."""
    model = bert_classifier(n_classes=2)
    dataset = _make_dataset()

    evaluate(model, dataset, force_cpu=True)

    # The report goes to stdout; check only for the expected marker string.
    captured = capsys.readouterr()
    assert "recommendations:" in captured.out


def test_predict_top_k():
    """predict_top_k() resolves an issue id to a single PredictionResult.

    Also verifies that an unknown issue id raises ValueError.
    """
    model = bert_classifier(n_classes=2)
    df = pd.DataFrame({
        "assignee": ["author_0", "author_1", "author_1", "author_0"],
        "title_body": _SAMPLE_TEXTS,
    }, index=[1, 2, 3, 4])
    labels = Labelling({
        "author_0": 0,
        "author_1": 1
    })
    split_data = SplitData.from_df(df, labels, 2)
    issue_id = 1

    result = predict_top_k(model, split_data, issue_id, top_n=2, force_cpu=True)
    _assert_prediction_result(result, top_n=2)

    # An issue id absent from the split data must raise ValueError.
    with pytest.raises(ValueError):
        predict_top_k(model, split_data, issue_id=99, top_n=2, force_cpu=True)