Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] fix, test gpt2 for AMP+TP (#4403)
* [shardformer] gpt2 tests fix
* [shardformer] test all optimizations (#4399)
Committed by: Hongxin Liu
Parent: 7596e9ae08
Commit: 21e0a42fd1
@@ -210,7 +210,7 @@ def check_weight(org_model: Module,
     if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
         sharded_weight_list = [
-            torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
+            torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group))
         ]
         dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
         sharded_weight = torch.cat(sharded_weight_list, dim=dim)
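The substantive change in this hunk is `torch.zeros([*sharded_weight.shape])` → `torch.zeros_like(sharded_weight)`. `torch.zeros` allocates in the default dtype (fp32), but under AMP the tensor-parallel shard is typically fp16, and `dist.all_gather` requires its output buffers to match the input tensor's dtype. A minimal sketch of the mismatch (the fp16 shard is illustrative, and this runs without a process group):

    import torch

    shard = torch.ones(4, 2, dtype=torch.float16)  # stand-in for an AMP weight shard

    buf_old = torch.zeros([*shard.shape])  # default dtype: torch.float32
    buf_new = torch.zeros_like(shard)      # inherits dtype and device: torch.float16

    print(buf_old.dtype, buf_new.dtype)    # torch.float32 torch.float16
    # dist.all_gather raises on mismatched dtypes, so the fp32 buffers from
    # torch.zeros break the gather whenever the shard is fp16.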
@@ -219,7 +219,7 @@ def check_weight(org_model: Module,
         print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

     assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
-        f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
+        f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"


 def check_grad(org_model: Module,
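This hunk only adds `{suffix}` to the assertion message so a failing parameter can be identified by name. The surrounding pattern is worth noting: both sides are upcast with `.float()` before `torch.allclose`, so an fp16 sharded weight can be compared against an fp32 reference. A sketch of that comparison; the `atol`/`rtol` defaults here are illustrative, the test takes them from the caller:

    import torch

    def weights_close(org_weight: torch.Tensor, sharded_weight: torch.Tensor,
                      atol: float = 5e-3, rtol: float = 1e-3) -> bool:
        # Upcast both sides to fp32 before comparing; the tolerances absorb
        # AMP rounding error instead of demanding bitwise equality.
        return torch.allclose(org_weight.float(), sharded_weight.float(),
                              atol=atol, rtol=rtol)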
@@ -236,9 +236,7 @@ def check_grad(org_model: Module,
     shard_weight = getattr_(sharded_model, suffix).weight

     if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-        shard_grad_list = [
-            torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
-        ]
+        shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
         dist.all_gather(shard_grad_list, shard_grad, tp_group)
         shard_grad = torch.cat(shard_grad_list, dim=dim)
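The first and third hunks converge on one gather pattern: allocate one dtype- and device-matched buffer per tensor-parallel rank, `all_gather` the shards, and concatenate along the sharding dimension. A self-contained sketch of that pattern; the helper name `gather_tp_tensor` is ours, not part of the test file:

    import torch
    import torch.distributed as dist

    def gather_tp_tensor(shard: torch.Tensor, tp_group, dim: int) -> torch.Tensor:
        # zeros_like keeps every buffer on the shard's device with its dtype,
        # which is what makes the gather safe under AMP (fp16/bf16 shards).
        buffers = [torch.zeros_like(shard) for _ in range(dist.get_world_size(tp_group))]
        dist.all_gather(buffers, shard, group=tp_group)
        # Reassemble the full tensor along the dimension it was sharded on.
        return torch.cat(buffers, dim=dim)

For a Megatron-style column-parallel linear weight the gather is typically along `dim=0`, for a row-parallel one along `dim=1`; the test passes `dim` in per layer.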