fix format (#607)

xyupeng 2022-04-01 13:31:06 +08:00 committed by binmakeswell
parent f2d2a1597a
commit d3d5bedc65


@@ -12,7 +12,6 @@ from functools import reduce
 import operator
 from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
 TensorShape = Union[torch.Size, List[int], Tuple[int]]
@@ -88,13 +87,11 @@ def _communicate(tensor_send_next=None,
     if tensor_send_prev is not None or recv_prev:
         if prev_rank is None:
-            prev_rank = gpc.get_prev_global_rank(
-                ParallelMode.PIPELINE)
+            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
     if tensor_send_next is not None or recv_next:
         if next_rank is None:
-            next_rank = gpc.get_next_global_rank(
-                ParallelMode.PIPELINE)
+            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
     if tensor_send_prev is not None:
         send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
@@ -183,9 +180,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
         next_rank (int, optional): The rank of the recipient of the tensor.
     """
     if not gpc.is_pipeline_last_stage():
-        _communicate(tensor_send_next=output_tensor,
-                     next_rank=next_rank,
-                     scatter_gather_tensors=scatter_gather_tensors)
+        _communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
 
 
 def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
@@ -338,15 +333,14 @@ def send_forward_backward_recv_forward_backward(output_tensor,
     Returns:
         Tuple(Tensor, Tensor): (the input tensor, the input gradient tensor)
     """
-    input_tensor, output_tensor_grad = _communicate(
-        tensor_send_next=output_tensor,
-        tensor_send_prev=input_tensor_grad,
-        recv_prev=recv_prev,
-        recv_next=recv_next,
-        recv_prev_shape=input_tensor_shape,
-        recv_next_shape=output_grad_shape,
-        prev_rank=prev_rank,
-        next_rank=next_rank,
-        dtype=dtype,
-        scatter_gather_tensors=scatter_gather_tensors)
+    input_tensor, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
+                                                    tensor_send_prev=input_tensor_grad,
+                                                    recv_prev=recv_prev,
+                                                    recv_next=recv_next,
+                                                    recv_prev_shape=input_tensor_shape,
+                                                    recv_next_shape=output_grad_shape,
+                                                    prev_rank=prev_rank,
+                                                    next_rank=next_rank,
+                                                    dtype=dtype,
+                                                    scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor, output_tensor_grad
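
The hunks above only reflow long `_communicate` calls onto single lines; behaviour is unchanged. For context, the sketch below illustrates the kind of neighbour-to-neighbour exchange `_communicate` performs, written against plain `torch.distributed` rather than ColossalAI's `gpc`. It is a minimal sketch, not the library's implementation: the function name `exchange_with_neighbours`, the modulo rank arithmetic, and the shape arguments are illustrative assumptions; it assumes the process group is already initialised (e.g. via torchrun) and that every rank is one pipeline stage.

# A minimal sketch (not the ColossalAI implementation) of the point-to-point
# pattern that _communicate wraps: each stage resolves its previous and next
# ranks, then exchanges activations/gradients with its neighbours.
import torch
import torch.distributed as dist


def exchange_with_neighbours(tensor_send_next=None, tensor_send_prev=None,
                             recv_prev_shape=None, recv_next_shape=None):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    prev_rank = (rank - 1) % world_size  # analogous to gpc.get_prev_global_rank(ParallelMode.PIPELINE)
    next_rank = (rank + 1) % world_size  # analogous to gpc.get_next_global_rank(ParallelMode.PIPELINE)

    ops = []
    tensor_recv_prev = tensor_recv_next = None
    if tensor_send_prev is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_prev, prev_rank))
    if recv_prev_shape is not None:
        tensor_recv_prev = torch.empty(recv_prev_shape)
        ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
    if tensor_send_next is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
    if recv_next_shape is not None:
        tensor_recv_next = torch.empty(recv_next_shape)
        ops.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))

    # Post all sends/receives together and wait, so paired stages cannot deadlock.
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return tensor_recv_prev, tensor_recv_next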