add interleaved pipeline, fix naive amp and update pipeline model initializer (#80)

Author: ver217
Date: 2021-12-20 23:26:19 +08:00 (committed via GitHub)
Parent: 91c327cb44
Commit: 8f02a88db2
17 changed files with 544 additions and 170 deletions


@@ -63,9 +63,6 @@ def _communicate(tensor_send_next=None,
next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE)
# rank = dist.get_rank()
rank = gpc.get_global_rank()
ops = []
if tensor_send_prev is not None:
send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
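The hunk above only shows a send op being queued; for orientation, the surrounding `_communicate` helper follows PyTorch's batched point-to-point pattern, and the `dtype` argument added throughout this file controls the precision of the pre-allocated receive buffers. A minimal sketch of that pattern (an illustration, not the library's exact implementation):

```python
import torch
import torch.distributed as dist

def _communicate_sketch(tensor_send_prev=None, tensor_send_next=None,
                        recv_prev_shape=None, recv_next_shape=None,
                        prev_rank=None, next_rank=None, dtype=torch.float):
    """Illustrative sketch of a batched P2P exchange between pipeline neighbours."""
    tensor_recv_prev = tensor_recv_next = None
    device = torch.cuda.current_device()
    # Receive buffers are pre-allocated; the new `dtype` argument lets callers
    # request e.g. fp16 buffers instead of the former float32 default.
    if recv_prev_shape is not None:
        tensor_recv_prev = torch.empty(recv_prev_shape, dtype=dtype, device=device)
    if recv_next_shape is not None:
        tensor_recv_next = torch.empty(recv_next_shape, dtype=dtype, device=device)

    ops = []
    if tensor_send_prev is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_prev, prev_rank))
    if tensor_recv_prev is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
    if tensor_send_next is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
    if tensor_recv_next is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
    if ops:
        # Issue all sends/receives together and wait before returning the buffers.
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return tensor_recv_prev, tensor_recv_next
```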
@@ -88,7 +85,7 @@ def _communicate(tensor_send_next=None,
return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, prev_rank=None):
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
"""Receives the input tensor from the previous member in pipeline.
:param input_tensor_shape: The shape of the tensor to be received
@@ -98,16 +95,17 @@ def recv_forward(input_tensor_shape, prev_rank=None):
:return: The input tensor in forward step
:rtype: :class:`torch.Tensor`
"""
if gpc.is_first_rank(ParallelMode.PIPELINE):
if gpc.is_pipeline_first_stage():
input_tensor = None
else:
input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank)
prev_rank=prev_rank,
dtype=dtype)
return input_tensor
def recv_backward(output_grad_shape, next_rank=None):
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
"""Receives the grad tensor from the next member in pipeline.
:param output_grad_shape: The shape of the tensor to be received
@@ -117,12 +115,13 @@ def recv_backward(output_grad_shape, next_rank=None):
:return: The grad of output tensor in forward step
:rtype: :class:`torch.Tensor`
"""
if gpc.is_last_rank(ParallelMode.PIPELINE):
if gpc.is_pipeline_last_stage():
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank)
next_rank=next_rank,
dtype=dtype)
return output_tensor_grad
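With `dtype` threaded through, a half-precision schedule can receive fp16 tensors directly instead of always getting float32 buffers. A rough usage sketch on a middle pipeline stage, assuming an initialized pipeline-parallel context and that these helpers are exported from `colossalai.communication` (the shape is illustrative):

```python
import torch
from colossalai.communication import recv_forward, recv_backward

# Illustrative activation shape: (micro-batch size, sequence length, hidden size).
tensor_shape = (16, 1024, 768)

# Pull the previous stage's activation, and later the next stage's gradient,
# both as fp16 buffers rather than the former float32 default.
input_tensor = recv_forward(tensor_shape, dtype=torch.half)
output_tensor_grad = recv_backward(tensor_shape, dtype=torch.half)
```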
@@ -134,7 +133,7 @@ def send_forward(output_tensor, next_rank=None):
:type output_tensor: :class:`torch.Tensor`
:type next_rank: int, optional
"""
if not gpc.is_last_rank(ParallelMode.PIPELINE):
if not gpc.is_pipeline_last_stage():
_communicate(tensor_send_next=output_tensor,
next_rank=next_rank)
@@ -147,7 +146,7 @@ def send_backward(input_tensor_grad, prev_rank=None):
:type input_tensor_grad: :class:`torch.Tensor`
:type prev_rank: int, optional
"""
if not gpc.is_first_rank(ParallelMode.PIPELINE):
if not gpc.is_pipeline_first_stage():
_communicate(tensor_send_prev=input_tensor_grad,
prev_rank=prev_rank)
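The plain send helpers take no `dtype` because they transmit an existing tensor rather than allocating a buffer. Combined with the receives above, one forward/backward step of a middle stage looks roughly like this; `stage_module` and the shapes are placeholders, not names from the commit:

```python
import torch
import torch.nn as nn
from colossalai.communication import (recv_forward, recv_backward,
                                      send_forward, send_backward)

# Placeholders for illustration; in practice these come from the model and config.
tensor_shape = (16, 1024, 768)
output_shape = (16, 1024, 768)
stage_module = nn.Linear(768, 768).half().cuda()

input_tensor = recv_forward(tensor_shape, dtype=torch.half)
input_tensor.requires_grad_()               # so backward produces a grad for the input
output_tensor = stage_module(input_tensor)  # this stage's slice of the model
send_forward(output_tensor)                 # pass the activation to the next stage

output_tensor_grad = recv_backward(output_shape, dtype=torch.half)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
send_backward(input_tensor.grad)            # pass the input gradient to the previous stage
```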
@@ -155,7 +154,8 @@ def send_backward(input_tensor_grad, prev_rank=None):
def send_forward_recv_backward(output_tensor,
output_grad_shape,
recv_next=True,
next_rank=None):
next_rank=None,
dtype=torch.float):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while receiving the grad tensor from the
next member in pipeline.
@@ -167,20 +167,22 @@ def send_forward_recv_backward(output_tensor,
:return: The grad of output tensor in forward step
:rtype: :class:`torch.Tensor`
"""
if gpc.is_last_rank(ParallelMode.PIPELINE):
if gpc.is_pipeline_last_stage():
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank)
next_rank=next_rank,
dtype=dtype)
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None):
prev_rank=None,
dtype=torch.float):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while receiving the input tensor from the
previous member in pipeline.
@@ -192,13 +194,14 @@ def send_backward_recv_forward(input_tensor_grad,
:return: The input tensor in forward step
:rtype: :class:`torch.Tensor`
"""
if gpc.is_first_rank(ParallelMode.PIPELINE):
if gpc.is_pipeline_first_stage():
input_tensor = None
else:
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank)
prev_rank=prev_rank,
dtype=dtype)
return input_tensor
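The combined send/receive helpers exist so the steady (1F1B) phase of a schedule can overlap the send for one micro-batch with the receive for another in a single batched call. A sketch of one steady-state iteration on a middle stage, reusing the placeholder names from the previous sketch:

```python
from colossalai.communication import (send_forward_recv_backward,
                                      send_backward_recv_forward)

# Steady-state 1F1B iteration (illustrative); `stage_module`, the shapes and
# `input_tensor` carry over from the warm-up sketch above.
output_tensor = stage_module(input_tensor)
# Send this micro-batch's activation forward while pulling back the gradient
# of an earlier micro-batch from the next stage.
output_tensor_grad = send_forward_recv_backward(output_tensor, output_shape,
                                                dtype=torch.half)
# ... run the backward pass for that earlier micro-batch here, producing
# input_tensor_grad ...
# Then send its input gradient back while fetching the next micro-batch's activation.
input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shape,
                                          dtype=torch.half)
```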
@@ -206,7 +209,8 @@ def send_forward_recv_forward(output_tensor,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
next_rank=None):
next_rank=None,
dtype=torch.float):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while receiving the input tensor from the
previous member in pipeline.
@@ -222,7 +226,8 @@ def send_forward_recv_forward(output_tensor,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank)
next_rank=next_rank,
dtype=dtype)
return input_tensor
@@ -230,7 +235,8 @@ def send_backward_recv_backward(input_tensor_grad,
output_grad_shape,
recv_next=True,
prev_rank=None,
next_rank=None):
next_rank=None,
dtype=torch.float):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while receiving the grad tensor from the
next member in pipeline.
@@ -246,7 +252,8 @@ def send_backward_recv_backward(input_tensor_grad,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank)
next_rank=next_rank,
dtype=dtype)
return output_tensor_grad
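The forward-only and backward-only combined variants are mostly relevant to the interleaved schedule this commit adds, where each rank holds several model chunks, so a stage may be passing one chunk's activation forward while still needing another chunk's input or gradient. A hedged sketch of how they might be called (chunk bookkeeping omitted; names are placeholders):

```python
from colossalai.communication import (send_forward_recv_forward,
                                      send_backward_recv_backward)

# Pass one chunk's activation to the next stage while receiving the input
# activation for another chunk from the previous stage.
next_input = send_forward_recv_forward(output_tensor, tensor_shape,
                                       recv_prev=True, dtype=torch.half)

# Mirror image on the backward side: send an input gradient back while
# receiving an output gradient for another chunk from the next stage.
next_output_grad = send_backward_recv_backward(input_tensor_grad, output_shape,
                                               recv_next=True, dtype=torch.half)
```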
@@ -257,7 +264,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None):
next_rank=None,
dtype=torch.float):
"""Batched communication operation. Sends the input tensor to the next and
the grad tensor to the previous, while receiving the grad tensor from the
next and the input tensor from the previous.
@@ -281,5 +289,6 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank)
next_rank=next_rank,
dtype=dtype)
return input_tensor, output_tensor_grad
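Finally, the fully combined call lets a steady-state step exchange all four tensors in one batched operation. The middle positional arguments are not shown in this hunk, so their order below is assumed from the other helpers; treat it as a sketch only:

```python
from colossalai.communication import send_forward_backward_recv_forward_backward

input_tensor, output_tensor_grad = send_forward_backward_recv_forward_backward(
    output_tensor,       # activation going to the next stage
    input_tensor_grad,   # gradient going to the previous stage
    tensor_shape,        # shape of the activation to receive (assumed argument order)
    output_shape,        # shape of the gradient to receive (assumed argument order)
    recv_prev=True,
    recv_next=True,
    dtype=torch.half)
```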