[shardformer] tests for 3d parallel (#4493)

This commit is contained in:
Jianghai
2023-08-23 15:05:24 +08:00
committed by GitHub
parent 59e252ecdb
commit e04436a82a
10 changed files with 324 additions and 5 deletions

View File

@@ -245,7 +245,6 @@ def check_grad(org_model: Module,
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)