updated collective ops api (#1054)

アマデウス 2022-06-02 12:52:27 +08:00 committed by GitHub
parent 51b9a49655
commit 2c42b230f3


@@ -13,7 +13,6 @@ from colossalai.core import global_context as gpc
 def all_gather(tensor: Tensor,
                dim: int,
                parallel_mode: ParallelMode,
-               on_cpu: bool = False,
                async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
     specific dimension.
@@ -26,7 +25,6 @@ def all_gather(tensor: Tensor,
         tensor (:class:`torch.Tensor`): Tensor to be gathered.
         dim (int): The dimension concatenating in.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -43,7 +41,7 @@ def all_gather(tensor: Tensor,
         shape[0] *= depth
         out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
         temp = list(torch.chunk(out, depth, dim=0))
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.all_gather(tensor_list=temp,
                                tensor=tensor.transpose(0, dim).contiguous(),
                                group=group,
@@ -59,7 +57,6 @@ def reduce_scatter(tensor: Tensor,
                    dim: int,
                    parallel_mode: ParallelMode,
                    op: ReduceOp = ReduceOp.SUM,
-                   on_cpu: bool = False,
                    async_op: bool = False) -> Tensor:
     r"""Reduces all tensors then scatters it in a specific dimension to all
     members in the parallel group.
@@ -76,7 +73,6 @@ def reduce_scatter(tensor: Tensor,
             should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
             More details about ReduceOp please refer to
             `ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -90,7 +86,7 @@ def reduce_scatter(tensor: Tensor,
     else:
         temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
         out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
     if async_op:
         return out, work
@@ -101,7 +97,6 @@ def reduce_scatter(tensor: Tensor,
 def all_reduce(tensor: Tensor,
                parallel_mode: ParallelMode,
                op: ReduceOp = ReduceOp.SUM,
-               on_cpu: bool = False,
                async_op: bool = False) -> Tensor:
     r"""Reduces the tensor data across whole parallel group in such a way that all get the final result.

@@ -116,7 +111,6 @@ def all_reduce(tensor: Tensor,
             should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
             More details about ReduceOp please refer to
             `ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -129,7 +123,7 @@ def all_reduce(tensor: Tensor,
         work = None
     else:
         out = tensor.contiguous()
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.all_reduce(out, op=op, group=group, async_op=async_op)
     if async_op:
         return out, work
@@ -137,7 +131,7 @@ def all_reduce(tensor: Tensor,
         return out


-def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: bool = False, async_op: bool = False):
+def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
     r"""Broadcast tensors to whole parallel group. Tensor must have the same
     number of elements in all processes participating in the collective.

@@ -149,7 +143,6 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: bool = False,
         tensor (:class:`torch.Tensor`): Tensor to be broadcast.
         src (int): Source rank.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -162,7 +155,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: bool = False,
         work = None
     else:
         out = tensor.contiguous()
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.broadcast(out, src=src, group=group, async_op=async_op)
     if async_op:
         return out, work
@@ -174,7 +167,6 @@ def reduce(tensor: Tensor,
            dst: int,
            parallel_mode: ParallelMode,
            op: ReduceOp = ReduceOp.SUM,
-           on_cpu: bool = False,
            async_op: bool = False):
     r"""Reduce tensors across whole parallel group. Only the process with
     rank ``dst`` is going to receive the final result.
@@ -187,7 +179,6 @@ def reduce(tensor: Tensor,
         tensor (:class:`torch.Tensor`): Tensor to be reduced.
         dst (int): Destination rank.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -200,7 +191,7 @@ def reduce(tensor: Tensor,
         work = None
     else:
         out = tensor.contiguous()
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.reduce(out, dst=dst, op=op, group=group, async_op=async_op)
     if async_op:
         return out, work
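
With this change the collective helpers infer the process group from the tensor's device instead of taking an explicit on_cpu flag: a tensor on CPU is routed to the group returned by gpc.get_cpu_group (the Gloo group), while a GPU tensor keeps using gpc.get_group. A minimal usage sketch of the updated API, assuming a ColossalAI distributed context is already initialized and that the helpers are imported from colossalai.communication; the tensor shapes and the choice of ParallelMode.GLOBAL are illustrative only:

import torch
from colossalai.communication import all_gather, all_reduce
from colossalai.context import ParallelMode

# GPU path: the tensor lives on CUDA, so the regular device group is used.
cuda_tensor = torch.ones(4, 8, device="cuda")
gathered = all_gather(cuda_tensor, dim=0, parallel_mode=ParallelMode.GLOBAL)

# CPU path: no on_cpu argument any more; the CPU (Gloo) group is selected
# automatically because tensor.device.type == "cpu".
cpu_tensor = torch.ones(4, 8)
reduced = all_reduce(cpu_tensor, parallel_mode=ParallelMode.GLOBAL)

# Asynchronous call: returns the output tensor together with a work handle.
out, work = all_reduce(cuda_tensor, parallel_mode=ParallelMode.GLOBAL, async_op=True)
work.wait()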