mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[autoparallel] update CommSpec to CommActions (#1768)
* [autoparallel] update CommSpec to CommActions * polish code
This commit is contained in:
@@ -41,7 +41,7 @@ def _split(tensor, comm_spec):
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
start = length * rank_list.index(dist.get_rank())
|
||||
output = torch.narrow(tensor, dim, start, length)
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
@@ -76,6 +76,8 @@ def _all_reduce(tensor, comm_spec):
|
||||
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
|
||||
return tensor
|
||||
|
||||
|
Reference in New Issue
Block a user