mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-04-26 09:42:27 +00:00
add bert for unitest and sharded model is not able to pass the bert case
This commit is contained in:
@@ -75,7 +75,7 @@ def check_grads_padding(model, zero_model, loose=False):
|
||||
if zero_grad.size(0) > grad.size(0):
|
||||
zero_grad = zero_grad[:grad.size(0)]
|
||||
assert grad.dtype == zero_grad.dtype
|
||||
assert allclose(grad, zero_grad, loose=loose), f'{grad} vs {zero_grad}'
|
||||
assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'
|
||||
|
||||
|
||||
def check_params_padding(model, zero_model, loose=False):
|
||||
|
||||
@@ -31,14 +31,25 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
|
||||
loss.backward()
|
||||
|
||||
|
||||
def run_bert_fwd_bwd(model, data, label, enable_autocast=False):
|
||||
model.train()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
output = model(input_ids=data, labels=label)
|
||||
loss = output[0]
|
||||
if isinstance(model, ShardedModelV2):
|
||||
model.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
test_models = ['repeated_computed_layers', 'resnet18']
|
||||
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
|
||||
shard_strategy = TensorShardStrategy()
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
shard_strategy = TensorShardStrategy()
|
||||
model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
|
||||
model = model().half().cuda()
|
||||
model = model(checkpoint=True).half().cuda()
|
||||
zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
|
||||
if dist.get_world_size() > 1:
|
||||
model = DDP(model)
|
||||
@@ -46,9 +57,16 @@ def run_dist(rank, world_size, port):
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
data, label = data.half().cuda(), label.cuda()
|
||||
run_fwd_bwd(model, data, label, criterion, False)
|
||||
run_fwd_bwd(zero_model, data, label, criterion, False)
|
||||
|
||||
if model_name == 'bert':
|
||||
data, label = data.cuda(), label.cuda()
|
||||
run_bert_fwd_bwd(model, data, label, False)
|
||||
run_bert_fwd_bwd(zero_model, data, label, False)
|
||||
else:
|
||||
data, label = data.half().cuda(), label.cuda()
|
||||
run_fwd_bwd(model, data, label, criterion, False)
|
||||
run_fwd_bwd(zero_model, data, label, criterion, False)
|
||||
|
||||
if dist.get_world_size() > 1:
|
||||
check_grads_padding(model, zero_model, loose=True)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user