mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-04-26 01:35:21 +00:00
adapting bert unitest interface
This commit is contained in:
@@ -31,11 +31,11 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
|
||||
loss.backward()
|
||||
|
||||
|
||||
def run_bert_fwd_bwd(model, data, label, enable_autocast=False):
|
||||
# with no criterion
|
||||
def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
|
||||
model.train()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
output = model(input_ids=data, labels=label)
|
||||
loss = output[0]
|
||||
loss = model(data, label)
|
||||
if isinstance(model, ShardedModelV2):
|
||||
model.backward(loss)
|
||||
else:
|
||||
@@ -60,8 +60,8 @@ def run_dist(rank, world_size, port):
|
||||
|
||||
if model_name == 'bert':
|
||||
data, label = data.cuda(), label.cuda()
|
||||
run_bert_fwd_bwd(model, data, label, False)
|
||||
run_bert_fwd_bwd(zero_model, data, label, False)
|
||||
run_fwd_bwd_no_criterion(model, data, label, False)
|
||||
run_fwd_bwd_no_criterion(zero_model, data, label, False)
|
||||
else:
|
||||
data, label = data.half().cuda(), label.cuda()
|
||||
run_fwd_bwd(model, data, label, criterion, False)
|
||||
|
||||
Reference in New Issue
Block a user