[autoparallel] add experimental permute handler (#2029)

This commit is contained in:
YuliangLiu0306
2022-11-27 20:26:52 +08:00
committed by GitHub
parent 95c4532fff
commit 81330b0352
11 changed files with 657 additions and 37 deletions

View File

@@ -23,6 +23,8 @@ def _all_gather(tensor, comm_spec):
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
]
# without this contiguous operation, the all gather may get some unexpected results.
tensor = tensor.contiguous()
dist.all_gather(tensor_list, tensor, group=process_group)
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
return output