Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 04:24:47 +00:00
[shardformer] added embedding gradient check (#4124)
@@ -18,20 +18,35 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     org_loss.backward()
     shard_loss.backward()
 
     # check grad equality
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
 
+    # check grad
+
     if org_model.__class__.__name__ == 'BertModel':
-        org_grad = org_model.encoder.layer[0].attention.self.query.weight.grad
-        shard_grad = sharded_model.encoder.layer[0].attention.self.query.weight.grad
+        bert = org_model
+        sharded_bert = sharded_model
     else:
-        org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
-        shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
+        bert = org_model.bert
+        sharded_bert = sharded_model.bert
+
+    # compare self attention grad
+    org_grad = bert.encoder.layer[0].attention.self.query.weight.grad
+    shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad
 
     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
     shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
     all_shard_grad = torch.cat(shard_grad_list, dim=0)
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
 
-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+    # compare embedding grad
+    org_grad = bert.embeddings.word_embeddings.weight.grad
+    shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad
+
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
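One detail worth noting when reusing this pattern: torch.distributed.all_gather writes into its list argument in place and returns None in the default synchronous mode, so the `shard_grad = torch.distributed.all_gather(...)` assignment above discards the call's result; the test still works only because it reads the gathered shards from `shard_grad_list`. A minimal sketch of the gather-and-compare check as a standalone helper follows; the function name `assert_sharded_grad_matches`, the `world_size` and `dim` parameters, and the use of torch.zeros_like are illustrative choices, not part of this commit.

    import torch
    import torch.distributed as dist

    def assert_sharded_grad_matches(org_param, shard_param, world_size, dim=0, atol=1e-5):
        """Gather a sharded parameter's gradient from every rank along `dim`
        and compare it with the gradient of the unsharded reference parameter."""
        shard_grad = shard_param.grad
        # all_gather fills shard_grad_list in place; its return value is None,
        # so the result must be read from the list, not from the call itself.
        shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
        dist.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=dim)
        assert torch.allclose(org_param.grad, all_shard_grad, atol=atol), \
            f"shard model grad is not equal to origin model grad\n{org_param.grad}\n{all_shard_grad}"

With such a helper, the two checks in the hunk would collapse to one call on the query weight and one on the word embeddings; `dim=0` and `world_size=2` mirror the `torch.cat(..., dim=0)` and the hard-coded two-rank gather in the test above.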