adapting bert unit test interface

This commit is contained in:
jiaruifang 2022-03-09 11:26:10 +08:00 committed by ver217
parent d984dff88c
commit 02c12f128b
2 changed files with 18 additions and 6 deletions

View File

@ -48,9 +48,21 @@ def get_training_components():
num_hidden_layers=num_layer, num_hidden_layers=num_layer,
) )
print('building BertForSequenceClassification model') print('building BertForSequenceClassification model')
model = BertForSequenceClassification(config)
# adapt Hugging Face BertForSequenceClassification to the single calling interface used by the unit tests
class ModelAaptor(BertForSequenceClassification):
    """Adapter exposing a ``model(data, label) -> loss`` calling convention.

    The parent ``forward`` is invoked with keyword arguments and only the
    first element of its result is handed back to the caller.
    """

    def forward(self, input_ids, labels):
        """Run the parent forward pass and return its first output.

        Args:
            input_ids: the data batch, forwarded as ``input_ids``.
            labels: the label batch, forwarded as ``labels``.

        Returns:
            Element ``0`` of the parent's output (the loss, per the
            original interface description).
        """
        parent_outputs = super().forward(input_ids=input_ids, labels=labels)
        loss = parent_outputs[0]
        return loss
model = ModelAaptor(config)
if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
return model return model
trainloader = get_bert_data_loader(batch_size=2, trainloader = get_bert_data_loader(batch_size=2,

View File

@ -31,11 +31,11 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
loss.backward() loss.backward()
def run_bert_fwd_bwd(model, data, label, enable_autocast=False): # with no criterion
def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
model.train() model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast): with torch.cuda.amp.autocast(enabled=enable_autocast):
output = model(input_ids=data, labels=label) loss = model(data, label)
loss = output[0]
if isinstance(model, ShardedModelV2): if isinstance(model, ShardedModelV2):
model.backward(loss) model.backward(loss)
else: else:
@ -60,8 +60,8 @@ def run_dist(rank, world_size, port):
if model_name == 'bert': if model_name == 'bert':
data, label = data.cuda(), label.cuda() data, label = data.cuda(), label.cuda()
run_bert_fwd_bwd(model, data, label, False) run_fwd_bwd_no_criterion(model, data, label, False)
run_bert_fwd_bwd(zero_model, data, label, False) run_fwd_bwd_no_criterion(zero_model, data, label, False)
else: else:
data, label = data.half().cuda(), label.cuda() data, label = data.half().cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, False) run_fwd_bwd(model, data, label, criterion, False)