"""Unit tests for ``src.modelimpl.classifier.bert_classifier``."""
from transformers import BertForSequenceClassification

from src.modelimpl.classifier import bert_classifier


def test_bert_classifier():
    """Check that ``bert_classifier`` builds a correctly sized BERT classifier.

    ``bert_classifier(n_classes)`` is expected to return a Hugging Face
    ``BertForSequenceClassification`` whose configuration and classification
    head are both sized for ``n_classes`` labels.
    """
    n_classes = 5
    model = bert_classifier(n_classes)

    # The factory must return the Hugging Face sequence-classification model.
    assert isinstance(model, BertForSequenceClassification), (
        f"expected BertForSequenceClassification, got {type(model).__name__}"
    )

    # The model config must record the requested number of labels.
    assert model.config.num_labels == n_classes, (
        f"config.num_labels={model.config.num_labels}, expected {n_classes}"
    )

    # The linear head must actually emit one logit per class. Checking only the
    # config would miss a head that was built or replaced with a different width.
    assert model.classifier.out_features == n_classes, (
        f"classifier.out_features={model.classifier.out_features}, "
        f"expected {n_classes}"
    )