From d3d5bedc65efba85361cd2b2b21e447f50cceb62 Mon Sep 17 00:00:00 2001
From: xyupeng <99191637+xyupeng@users.noreply.github.com>
Date: Fri, 1 Apr 2022 13:31:06 +0800
Subject: [PATCH] fix format (#607)

---
 colossalai/communication/p2p.py | 32 +++++++++++++-------------------
 1 file changed, 13 insertions(+), 19 deletions(-)

diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py
index 76fc251ab..220f04861 100644
--- a/colossalai/communication/p2p.py
+++ b/colossalai/communication/p2p.py
@@ -12,7 +12,6 @@ from functools import reduce
 import operator
 
 from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
-
 TensorShape = Union[torch.Size, List[int], Tuple[int]]
 
 
@@ -88,13 +87,11 @@ def _communicate(tensor_send_next=None,
 
     if tensor_send_prev is not None or recv_prev:
         if prev_rank is None:
-            prev_rank = gpc.get_prev_global_rank(
-                ParallelMode.PIPELINE)
+            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
 
     if tensor_send_next is not None or recv_next:
         if next_rank is None:
-            next_rank = gpc.get_next_global_rank(
-                ParallelMode.PIPELINE)
+            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
 
     if tensor_send_prev is not None:
         send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
@@ -183,9 +180,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
         next_rank (int, optional): The rank of the recipient of the tensor.
     """
     if not gpc.is_pipeline_last_stage():
-        _communicate(tensor_send_next=output_tensor,
-                     next_rank=next_rank,
-                     scatter_gather_tensors=scatter_gather_tensors)
+        _communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
 
 
 def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
@@ -338,15 +333,14 @@ def send_forward_backward_recv_forward_backward(output_tensor,
     Returns:
         Tuple(Tensor, Tensor): (the input tensor, the input gradient tensor)
     """
-    input_tensor, output_tensor_grad = _communicate(
-        tensor_send_next=output_tensor,
-        tensor_send_prev=input_tensor_grad,
-        recv_prev=recv_prev,
-        recv_next=recv_next,
-        recv_prev_shape=input_tensor_shape,
-        recv_next_shape=output_grad_shape,
-        prev_rank=prev_rank,
-        next_rank=next_rank,
-        dtype=dtype,
-        scatter_gather_tensors=scatter_gather_tensors)
+    input_tensor, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
+                                                    tensor_send_prev=input_tensor_grad,
+                                                    recv_prev=recv_prev,
+                                                    recv_next=recv_next,
+                                                    recv_prev_shape=input_tensor_shape,
+                                                    recv_next_shape=output_grad_shape,
+                                                    prev_rank=prev_rank,
+                                                    next_rank=next_rank,
+                                                    dtype=dtype,
+                                                    scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor, output_tensor_grad
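
For readers unfamiliar with these helpers, below is a minimal usage sketch of the functions touched by this patch. It assumes a ColossalAI pipeline-parallel context has already been initialized (e.g. via colossalai.launch with a pipeline parallel size greater than 1). The import paths, function names, and keyword arguments are taken from the hunks above; the step functions themselves (run_warmup_step, run_exchange_step), the stage guards for recv_prev/recv_next, and the choice of dtype are hypothetical illustration, not part of the commit.

import torch

from colossalai.core import global_context as gpc
from colossalai.communication.p2p import (send_forward,
                                          send_forward_backward_recv_forward_backward)


def run_warmup_step(output_tensor):
    # Warmup forward pass: only ship this stage's activation downstream.
    # send_forward is a no-op on the last pipeline stage (the guard lives
    # inside the helper, as shown in the hunk above).
    send_forward(output_tensor)


def run_exchange_step(output_tensor, input_tensor_grad,
                      input_tensor_shape, output_grad_shape):
    # Steady-state exchange: send the activation to the next stage and the
    # input gradient to the previous stage, while receiving the next
    # microbatch's activation and the gradient of this stage's output in a
    # single fused call.
    input_tensor, output_tensor_grad = send_forward_backward_recv_forward_backward(
        output_tensor,
        input_tensor_grad,
        input_tensor_shape=input_tensor_shape,
        output_grad_shape=output_grad_shape,
        recv_prev=not gpc.is_pipeline_first_stage(),
        recv_next=not gpc.is_pipeline_last_stage(),
        dtype=torch.float)
    return input_tensor, output_tensor_grad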