adapting bert unitest interface

This commit is contained in:
jiaruifang
2022-03-09 11:26:10 +08:00
committed by ver217
parent d984dff88c
commit 02c12f128b
2 changed files with 18 additions and 6 deletions

View File

@@ -31,11 +31,11 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
loss.backward()
def run_bert_fwd_bwd(model, data, label, enable_autocast=False):
# with no criterion
def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
output = model(input_ids=data, labels=label)
loss = output[0]
loss = model(data, label)
if isinstance(model, ShardedModelV2):
model.backward(loss)
else:
@@ -60,8 +60,8 @@ def run_dist(rank, world_size, port):
if model_name == 'bert':
data, label = data.cuda(), label.cuda()
run_bert_fwd_bwd(model, data, label, False)
run_bert_fwd_bwd(zero_model, data, label, False)
run_fwd_bwd_no_criterion(model, data, label, False)
run_fwd_bwd_no_criterion(zero_model, data, label, False)
else:
data, label = data.half().cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, False)