mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-30 01:02:00 +00:00
adapting bert unit test interface
This commit is contained in:
parent
d984dff88c
commit
02c12f128b
@ -48,9 +48,21 @@ def get_training_components():
|
|||||||
num_hidden_layers=num_layer,
|
num_hidden_layers=num_layer,
|
||||||
)
|
)
|
||||||
print('building BertForSequenceClassification model')
|
print('building BertForSequenceClassification model')
|
||||||
model = BertForSequenceClassification(config)
|
|
||||||
|
# adapt Hugging Face BertForSequenceClassification to the single calling interface used by the unit tests
|
||||||
|
class ModelAaptor(BertForSequenceClassification):
|
||||||
|
|
||||||
|
def forward(self, input_ids, labels):
|
||||||
|
"""
|
||||||
|
inputs: data, label
|
||||||
|
outputs: loss
|
||||||
|
"""
|
||||||
|
return super().forward(input_ids=input_ids, labels=labels)[0]
|
||||||
|
|
||||||
|
model = ModelAaptor(config)
|
||||||
if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
|
if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
trainloader = get_bert_data_loader(batch_size=2,
|
trainloader = get_bert_data_loader(batch_size=2,
|
||||||
|
@ -31,11 +31,11 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|
||||||
def run_bert_fwd_bwd(model, data, label, enable_autocast=False):
|
# with no criterion
|
||||||
|
def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
|
||||||
model.train()
|
model.train()
|
||||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||||
output = model(input_ids=data, labels=label)
|
loss = model(data, label)
|
||||||
loss = output[0]
|
|
||||||
if isinstance(model, ShardedModelV2):
|
if isinstance(model, ShardedModelV2):
|
||||||
model.backward(loss)
|
model.backward(loss)
|
||||||
else:
|
else:
|
||||||
@ -60,8 +60,8 @@ def run_dist(rank, world_size, port):
|
|||||||
|
|
||||||
if model_name == 'bert':
|
if model_name == 'bert':
|
||||||
data, label = data.cuda(), label.cuda()
|
data, label = data.cuda(), label.cuda()
|
||||||
run_bert_fwd_bwd(model, data, label, False)
|
run_fwd_bwd_no_criterion(model, data, label, False)
|
||||||
run_bert_fwd_bwd(zero_model, data, label, False)
|
run_fwd_bwd_no_criterion(zero_model, data, label, False)
|
||||||
else:
|
else:
|
||||||
data, label = data.half().cuda(), label.cuda()
|
data, label = data.half().cuda(), label.cuda()
|
||||||
run_fwd_bwd(model, data, label, criterion, False)
|
run_fwd_bwd(model, data, label, criterion, False)
|
||||||
|
Loading…
Reference in New Issue
Block a user