[shardformer]fix, test gpt2 for AMP+TP (#4403)

* [shardformer] gpt2 tests fix

[shardformer] test all optimizations (#4399)

[shardformer] test all optimizations

[shardformer] test all optimizations

[shardformer] test all optimizations

[shardformer] gpt2 tests fix

* [shardformer] gpt2 tests fix
This commit is contained in:
flybird11111
2023-08-11 11:44:23 +08:00
committed by Hongxin Liu
parent 7596e9ae08
commit 21e0a42fd1
2 changed files with 6 additions and 10 deletions

View File

@@ -210,7 +210,7 @@ def check_weight(org_model: Module,
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
sharded_weight_list = [
torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group))
]
dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
sharded_weight = torch.cat(sharded_weight_list, dim=dim)
@@ -219,7 +219,7 @@ def check_weight(org_model: Module,
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
def check_grad(org_model: Module,
@@ -236,9 +236,7 @@ def check_grad(org_model: Module,
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [
torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
]
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)