update examples and sphnix docs for the new api (#63)

This commit is contained in:
Frank Lee
2021-12-13 22:07:01 +08:00
committed by GitHub
parent 7d3711058f
commit 35813ed3c4
124 changed files with 1251 additions and 1462 deletions

View File

@@ -18,11 +18,11 @@ def all_gather(tensor: Tensor, dim: int,
:param tensor: Tensor to be gathered
:param dim: The dimension concatenating in
:param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor
:type tensor: :class:`torch.Tensor`
:type dim: int
:type parallel_mode: ParallelMode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:return: The tensor generated by all-gather
:rtype: Tensor
:rtype: :class:`torch.Tensor`
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
@@ -54,11 +54,11 @@ def reduce_scatter(tensor: Tensor, dim: int,
:param tensor: Tensor to be reduced and scattered
:param dim: The dimension scattering in
:param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor
:type tensor: :class:`torch.Tensor`
:type dim: int
:type parallel_mode: ParallelMode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:return: The tensor generated by reduce-scatter
:rtype: Tensor
:rtype: :class:`Tensor`
"""
depth = gpc.get_world_size(parallel_mode)
# temp = list(torch.chunk(tensor, depth, dim=dim))