This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
soft-analytics-01/tests/test_modelimpl_classifier.py
Claudio Maggioni 07232eddcc Final version of the bug-triaging project
Commit history has been discarded to remove large files from the repo.
2024-01-03 15:22:56 +01:00

File stats: 13 lines, 446 B, Python.

from transformers import BertForSequenceClassification
from src.modelimpl.classifier import bert_classifier
def test_bert_classifier():
    """Check that ``bert_classifier`` builds a BERT classifier sized for ``n_classes``.

    Verifies three things about the object returned by
    ``bert_classifier(n_classes)``:

    * it is a ``BertForSequenceClassification`` instance;
    * its config records ``n_classes`` labels;
    * its classification head actually outputs ``n_classes`` logits.
    """
    n_classes = 5
    model = bert_classifier(n_classes)

    # The factory must return the Hugging Face sequence-classification model.
    assert isinstance(model, BertForSequenceClassification)

    # The configured label count must match the requested number of classes.
    assert model.config.num_labels == n_classes

    # The linear head must emit one logit per class. Checking the layer itself
    # guards against a config that says one thing while the head says another.
    assert model.classifier.out_features == n_classes