"""Tests for the ``bert_classifier`` model factory."""
from transformers import BertForSequenceClassification

from src.modelimpl.classifier import bert_classifier


def test_bert_classifier():
    """Check that ``bert_classifier`` builds a BERT classifier sized to the label count."""
    num_labels = 5
    classifier = bert_classifier(num_labels)

    # The factory is expected to return a transformers BertForSequenceClassification.
    assert isinstance(classifier, BertForSequenceClassification)

    # The model config must record exactly the number of labels requested.
    assert classifier.config.num_labels == num_labels