adapting bert unitest interface

This commit is contained in:
jiaruifang
2022-03-09 11:26:10 +08:00
committed by Frank Lee
parent 7977422aeb
commit 4d94cd513e
2 changed files with 18 additions and 6 deletions

View File

@@ -48,9 +48,21 @@ def get_training_components():
num_hidden_layers=num_layer,
)
print('building BertForSequenceClassification model')
model = BertForSequenceClassification(config)
# adapting huggingface BertForSequenceClassification for single unitest calling interface
class ModelAaptor(BertForSequenceClassification):
def forward(self, input_ids, labels):
"""
inputs: data, label
outputs: loss
"""
return super().forward(input_ids=input_ids, labels=labels)[0]
model = ModelAaptor(config)
if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
model.gradient_checkpointing_enable()
return model
trainloader = get_bert_data_loader(batch_size=2,