[autoparallel] update CommSpec to CommActions (#1768)

* [autoparallel] update CommSpec to CommActions

* polish code
This commit is contained in:
YuliangLiu0306
2022-10-28 09:57:43 +08:00
committed by GitHub
parent 16b0abf94f
commit b0f7c8bde8
7 changed files with 267 additions and 122 deletions

View File

@@ -41,7 +41,7 @@ def _split(tensor, comm_spec):
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
-output = torch.narrow(tensor, dim, start, length)
+output = torch.narrow(tensor, dim, start, length).contiguous()
return output
@@ -76,6 +76,8 @@ def _all_reduce(tensor, comm_spec):
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
+if not tensor.is_contiguous():
+tensor = tensor.contiguous()
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
return tensor