mirror of https://github.com/hpcaitech/ColossalAI.git
@@ -32,7 +32,8 @@ def _reduce(input_, parallel_mode):
     # skip if only one rank involved
     if gpc.get_world_size(parallel_mode) == 1:
         return input_
-    dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
+    group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
+    dist.all_reduce(input_, group=group)

     return input_

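The hunk above routes the all-reduce through a CPU process group (typically gloo-backed) when the tensor lives on the CPU, and keeps the default device group otherwise; NCCL groups only operate on CUDA tensors, so CPU tensors need the separate group. Below is a minimal sketch of the same pattern written against plain torch.distributed, not ColossalAI's own code: cpu_group and device_group are hypothetical stand-ins for gpc.get_cpu_group(parallel_mode) and gpc.get_group(parallel_mode), and the process group is assumed to be already initialized.

import torch
import torch.distributed as dist

def reduce_on_matching_group(input_, cpu_group, device_group):
    # Pick a group whose backend matches the tensor's device:
    # a gloo-backed group for CPU tensors, the device (NCCL) group for CUDA tensors.
    group = cpu_group if input_.device.type == "cpu" else device_group
    # Skip the collective entirely when only one rank participates.
    if dist.get_world_size(group) == 1:
        return input_
    # In-place sum across all ranks in the chosen group.
    dist.all_reduce(input_, group=group)
    return input_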
@@ -66,7 +67,8 @@ def _gather(input_, parallel_mode, dim=-1):
     rank = gpc.get_local_rank(parallel_mode)
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
-    torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
+    group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
+    torch.distributed.all_gather(tensor_list, input_, group=group)

     # concat
     output = torch.cat(tensor_list, dim=dim).contiguous()
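The _gather hunk applies the same device-aware group selection before the all-gather, then concatenates the per-rank shards along the requested dimension. A sketch of that gather-then-concatenate pattern outside ColossalAI, again with cpu_group and device_group as assumed stand-ins for the gpc group getters:

import torch
import torch.distributed as dist

def gather_along_dim(input_, cpu_group, device_group, dim=-1):
    # Same backend-matching rule as in the reduce sketch above.
    group = cpu_group if input_.device.type == "cpu" else device_group
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    # Pre-allocate one receive buffer per rank; this rank's slot reuses the local tensor.
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    dist.all_gather(tensor_list, input_, group=group)
    # Concatenate the gathered shards along `dim` into one contiguous tensor.
    return torch.cat(tensor_list, dim=dim).contiguous()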