# Unit tests for src.modelimpl.classifier.bert_classifier.
from transformers import BertForSequenceClassification
from src.modelimpl.classifier import bert_classifier


def test_bert_classifier():
    """Check that ``bert_classifier`` builds a correctly sized BERT classifier.

    ``bert_classifier(n_classes)`` is expected to return a
    ``transformers.BertForSequenceClassification`` configured for
    ``n_classes`` output labels, with a classification head of that width.
    """
    n_classes = 5

    model = bert_classifier(n_classes)

    # The factory must return the Hugging Face sequence-classification model type.
    assert isinstance(model, BertForSequenceClassification), (
        f"expected BertForSequenceClassification, got {type(model).__name__}"
    )

    # The model config must record exactly the requested number of labels.
    assert model.config.num_labels == n_classes

    # Changing config.num_labels after construction does not resize the head,
    # so also check the final linear layer's output width directly.
    assert model.classifier.out_features == n_classes