mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-03 10:36:47 +00:00
fix format (#607)
This commit is contained in:
parent
f2d2a1597a
commit
d3d5bedc65
@ -12,7 +12,6 @@ from functools import reduce
|
||||
import operator
|
||||
from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
|
||||
|
||||
|
||||
TensorShape = Union[torch.Size, List[int], Tuple[int]]
|
||||
|
||||
|
||||
@ -88,13 +87,11 @@ def _communicate(tensor_send_next=None,
|
||||
|
||||
if tensor_send_prev is not None or recv_prev:
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(
|
||||
ParallelMode.PIPELINE)
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
if tensor_send_next is not None or recv_next:
|
||||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(
|
||||
ParallelMode.PIPELINE)
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
if tensor_send_prev is not None:
|
||||
send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
|
||||
@ -183,9 +180,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not gpc.is_pipeline_last_stage():
|
||||
_communicate(tensor_send_next=output_tensor,
|
||||
next_rank=next_rank,
|
||||
scatter_gather_tensors=scatter_gather_tensors)
|
||||
_communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
|
||||
|
||||
|
||||
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
|
||||
@ -338,8 +333,7 @@ def send_forward_backward_recv_forward_backward(output_tensor,
|
||||
Returns:
|
||||
Tuple(Tensor, Tensor): (the input tensor, the input gradient tensor)
|
||||
"""
|
||||
input_tensor, output_tensor_grad = _communicate(
|
||||
tensor_send_next=output_tensor,
|
||||
input_tensor, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
|
||||
tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
|
Loading…
Reference in New Issue
Block a user