mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
adapting bert unitest interface
This commit is contained in:
@@ -48,9 +48,21 @@ def get_training_components():
|
||||
num_hidden_layers=num_layer,
|
||||
)
|
||||
print('building BertForSequenceClassification model')
|
||||
model = BertForSequenceClassification(config)
|
||||
|
||||
# adapting huggingface BertForSequenceClassification for single unitest calling interface
|
||||
class ModelAaptor(BertForSequenceClassification):
|
||||
|
||||
def forward(self, input_ids, labels):
|
||||
"""
|
||||
inputs: data, label
|
||||
outputs: loss
|
||||
"""
|
||||
return super().forward(input_ids=input_ids, labels=labels)[0]
|
||||
|
||||
model = ModelAaptor(config)
|
||||
if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
return model
|
||||
|
||||
trainloader = get_bert_data_loader(batch_size=2,
|
||||
|
Reference in New Issue
Block a user