mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[shardformer] added embedding gradient check (#4124)
This commit is contained in:
@@ -28,7 +28,10 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
|
||||
# check grad
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
|
||||
# unwrap model
|
||||
if hasattr(org_model, 'model'):
|
||||
opt_model = org_model.model
|
||||
shard_opt_model = sharded_model.model
|
||||
@@ -36,16 +39,23 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
opt_model = org_model
|
||||
shard_opt_model = sharded_model
|
||||
|
||||
# check attention grad
|
||||
org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
|
||||
shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
|
||||
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
|
||||
# check embedding grad
|
||||
org_grad = opt_model.decoder.embed_tokens.weight.grad
|
||||
shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
|
||||
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
|
||||
|
||||
def check_OPTModel(rank, world_size, port):
|
||||
@@ -65,3 +75,7 @@ def check_OPTModel(rank, world_size, port):
|
||||
@clear_cache_before_run()
|
||||
def test_OPTModel():
|
||||
spawn(check_OPTModel, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_OPTModel()
|
||||
|
Reference in New Issue
Block a user