This repository has been archived on 2024-10-22. You can view files and clone it, but cannot push or open issues or pull requests.
soft-analytics-01/tests/test_modelimpl_classifier.py

14 lines
446 B
Python
Raw Normal View History

from transformers import BertForSequenceClassification
from src.modelimpl.classifier import bert_classifier
def test_bert_classifier():
    """Check that ``bert_classifier`` builds a correctly sized BERT classifier.

    For a fixed number of classes this verifies that:

    * the returned object is a ``BertForSequenceClassification``;
    * its config records the requested number of labels;
    * the classification head actually emits that many logits, so a model
      whose config was patched after construction (leaving the head at its
      original width) is caught.
    """
    import torch  # local import: only this test needs torch directly

    n_classes = 5
    model = bert_classifier(n_classes)

    # The factory must return the expected transformers model class.
    assert isinstance(model, BertForSequenceClassification)

    # The config must record the requested number of labels...
    assert model.config.num_labels == n_classes
    # ...and the linear classification head must agree with it; the two can
    # diverge if num_labels is changed after the model has been built.
    assert model.classifier.out_features == n_classes

    # End-to-end check: a tiny dummy batch yields one logit per class.
    # Token id 0 is a valid index for any vocabulary size.
    batch_size, seq_len = 2, 8
    input_ids = torch.zeros((batch_size, seq_len), dtype=torch.long)
    model.eval()  # disable dropout for a deterministic forward pass
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    # Index 0 is the logits both for plain-tuple outputs (older transformers)
    # and for ModelOutput objects when no labels (hence no loss) are given.
    logits = outputs[0]
    assert tuple(logits.shape) == (batch_size, n_classes)