[shardformer] added embedding gradient check (#4124)
@@ -18,20 +18,36 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
    org_loss.backward()
    shard_loss.backward()

    # check that both models produce the same loss
    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"sharded model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"

    # unwrap the model: GPT2Model exposes .h and .wte directly, while the
    # task-specific variants (e.g. GPT2LMHeadModel) nest them under .transformer
    if org_model.__class__.__name__ != 'GPT2Model':
        org_model = org_model.transformer
        sharded_model = sharded_model.transformer

    # check mlp grad: c_fc is column-parallel, so each rank holds a slice of the
    # output dimension; gather the slices and concatenate along dim 1
    org_grad = org_model.h[0].mlp.c_fc.weight.grad
    shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad

    # all_gather fills the buffer list in place and returns None
    # (this test runs with a tensor parallel size of 2)
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)]
    torch.distributed.all_gather(shard_grad_list, shard_grad)
    all_shard_grad = torch.cat(shard_grad_list, dim=1)

    assert torch.allclose(
        org_grad, all_shard_grad,
        atol=1e-5), f"sharded model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"

    # check embedding grad: wte is vocab-parallel, so the shards split the
    # vocabulary dimension and are concatenated along dim 0
    org_grad = org_model.wte.weight.grad
    shard_grad = sharded_model.wte.weight.grad
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)]
    torch.distributed.all_gather(shard_grad_list, shard_grad)
    all_shard_grad = torch.cat(shard_grad_list, dim=0)

    assert torch.allclose(
        org_grad, all_shard_grad,
        atol=1e-5), f"sharded model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
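The same gather-and-compare pattern appears twice in the hunk, differing only in the concatenation dimension. As a minimal sketch of how it could be factored out (not part of this commit; the helper name check_grad and its signature are hypothetical):

import torch
import torch.distributed as dist


def check_grad(org_param, shard_param, dim, atol=1e-5):
    """Gather a sharded parameter's gradient from all ranks and compare it
    against the unsharded reference gradient.

    dim is the axis the parameter was sharded along: 1 for the column-parallel
    c_fc weight, 0 for the vocab-parallel wte weight.
    """
    shard_grad = shard_param.grad
    # one receive buffer per rank; all_gather fills them in place and returns None
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(dist.get_world_size())]
    dist.all_gather(shard_grad_list, shard_grad)
    all_shard_grad = torch.cat(shard_grad_list, dim=dim)
    assert torch.allclose(org_param.grad, all_shard_grad, atol=atol), \
        f"sharded model grad is not equal to original model grad\n{org_param.grad}\n{all_shard_grad}"

With such a helper, the two checks above would reduce to check_grad(org_model.h[0].mlp.c_fc.weight, sharded_model.h[0].mlp.c_fc.weight, dim=1) and check_grad(org_model.wte.weight, sharded_model.wte.weight, dim=0).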