[autoparallel] remove redundancy comm node (#1893)

Author: YuliangLiu0306 (committed by GitHub)
Date: 2022-11-15 10:53:41 +08:00
Commit: 36c0f3ea5b (parent: 9183e0dec5)
5 changed files with 23 additions and 20 deletions


@@ -23,9 +23,7 @@ def _all_gather(tensor, comm_spec):
                 torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
                 for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
             ]
-            tensor = tensor
-            group = process_group
-            dist.all_gather(tensor_list, tensor, group=group)
+            dist.all_gather(tensor_list, tensor, group=process_group)
             output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
             return output
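The hunk above only drops two no-op aliases (`tensor = tensor`, `group = process_group`) and passes the process group directly. For context, a minimal standalone sketch of the gather-then-concat pattern, assuming an already initialized torch.distributed process group; the wrapper name is illustrative and the argument names mirror the comm_spec fields:

import torch
import torch.distributed as dist

def all_gather_and_concat(tensor, gather_dim, process_group):
    # One receive buffer per rank in the group; all_gather fills them in rank order.
    world_size = dist.get_world_size(group=process_group)
    tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, group=process_group)
    # Stitch the gathered shards back together along the gather dimension.
    return torch.cat(tuple(tensor_list), dim=gather_dim).contiguous()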
@@ -37,7 +35,6 @@ def _split(tensor, comm_spec):
     process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
     for rank_list, _ in process_groups_list:
         if dist.get_rank() in rank_list:
-            tensor = tensor
             dim = comm_spec.shard_dim
             length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
             start = length * rank_list.index(dist.get_rank())
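The dim/length/start values computed above are the ingredients for slicing out this rank's shard. A hedged sketch of that split step; torch.narrow is an assumption about the lines not shown in this hunk, not necessarily what the file uses:

import torch
import torch.distributed as dist

def local_shard(tensor, shard_dim, rank_list):
    # Each rank keeps an equal-length slice along shard_dim, ordered by its
    # position in rank_list.
    length = tensor.shape[shard_dim] // len(rank_list)
    start = length * rank_list.index(dist.get_rank())
    return torch.narrow(tensor, shard_dim, start, length).contiguous()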
@@ -69,7 +66,7 @@ def _all_to_all(tensor, comm_spec):
             return output
-def _all_reduce(tensor, comm_spec):
+def _all_reduce(tensor, comm_spec, async_op=False):
     '''
     Implement all reduce operation on device mesh based on information provided by comm_spec.
     '''
@@ -78,7 +75,7 @@ def _all_reduce(tensor, comm_spec):
         if dist.get_rank() in rank_list:
             if not tensor.is_contiguous():
                 tensor = tensor.contiguous()
-            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
+            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
             return tensor
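A hedged usage sketch of the new async_op flag: when async_op=True, torch.distributed returns a Work handle instead of blocking, and the caller must wait on it before reading the reduced tensor. The wrapper name is illustrative and group setup is assumed:

import torch.distributed as dist
from torch.distributed import ReduceOp

def all_reduce_sum(tensor, process_group, async_op=False):
    # all_reduce requires a contiguous buffer.
    if not tensor.is_contiguous():
        tensor = tensor.contiguous()
    work = dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
    if async_op:
        # Other computation could be overlapped here before waiting.
        work.wait()
    return tensor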