diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py
index 1f20fca4f..d28d14016 100644
--- a/colossalai/communication/p2p.py
+++ b/colossalai/communication/p2p.py
@@ -1,16 +1,18 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+import operator
+from functools import reduce
 from typing import List, Tuple, Union
+
 import torch
 import torch.distributed as dist
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
-from functools import reduce
-import operator
-from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
+
+from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
 
 TensorShape = Union[torch.Size, List[int], Tuple[int]]
 
@@ -260,7 +262,7 @@ def send_forward_recv_backward(output_tensor,
                                next_rank=None,
                                dtype=torch.float,
                                scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
-    """Batched communication operation. Sends the input tensor to the 
+    """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the gradient tensor from the
     next stage in pipeline as the input gradient tensor of this stage.
 
@@ -319,7 +321,7 @@ def send_forward_recv_forward(output_tensor,
                               next_rank=None,
                               dtype=torch.float,
                               scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
-    """Batched communication operation. Sends the input tensor to the 
+    """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
 
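Context for reviewers: the two docstrings touched above belong to batched point-to-point helpers used in pipeline parallelism, where a stage must send its forward output and receive the returning gradient (or the next input) in one step. As a minimal sketch of that pattern, not ColossalAI's actual implementation (the hunks do not show the full signatures, and the helper and parameter names below are hypothetical), a combined send/recv can be written with torch.distributed's batched P2P API:

import torch
import torch.distributed as dist

def send_next_recv_next(output_tensor: torch.Tensor,
                        grad_shape: torch.Size,
                        next_rank: int) -> torch.Tensor:
    # Illustrative sketch only; assumes an initialized process group and
    # that the gradient shape is known to the receiving stage in advance.
    grad_tensor = torch.empty(grad_shape,
                              dtype=output_tensor.dtype,
                              device=output_tensor.device)
    # Post the forward send and the backward recv as one batch so that
    # adjacent stages cannot deadlock waiting on each other.
    ops = [
        dist.P2POp(dist.isend, output_tensor, next_rank),
        dist.P2POp(dist.irecv, grad_tensor, next_rank),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return grad_tensor

send_forward_recv_forward follows the same batched pattern, except the receive is posted against the previous stage's output rather than the next stage's gradient.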