mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 19:55:03 +00:00
update examples and sphnix docs for the new api (#63)
This commit is contained in:
@@ -96,7 +96,7 @@ def recv_forward(input_tensor_shape, prev_rank=None):
|
||||
:type input_tensor_shape: torch.Size
|
||||
:type prev_rank: int, optional
|
||||
:return: The input tensor in forward step
|
||||
:rtype: Tensor
|
||||
:rtype: :class:`torch.Tensor`
|
||||
"""
|
||||
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
input_tensor = None
|
||||
@@ -115,7 +115,7 @@ def recv_backward(output_grad_shape, next_rank=None):
|
||||
:type output_grad_shape: torch.Size
|
||||
:type next_rank: int, optional
|
||||
:return: The grad of output tensor in forward step
|
||||
:rtype: Tensor
|
||||
:rtype: :class:`torch.Tensor`
|
||||
"""
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
output_tensor_grad = None
|
||||
@@ -131,7 +131,7 @@ def send_forward(output_tensor, next_rank=None):
|
||||
|
||||
:param output_tensor: Tensor to be sent
|
||||
:param next_rank: The rank of the recipient of the tensor
|
||||
:type output_tensor: Tensor
|
||||
:type output_tensor: :class:`torch.Tensor`
|
||||
:type next_rank: int, optional
|
||||
"""
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
@@ -144,7 +144,7 @@ def send_backward(input_tensor_grad, prev_rank=None):
|
||||
|
||||
:param input_tensor_grad: Tensor to be sent
|
||||
:param prev_rank: The rank of the recipient of the tensor
|
||||
:type input_tensor_grad: Tensor
|
||||
:type input_tensor_grad: :class:`torch.Tensor`
|
||||
:type prev_rank: int, optional
|
||||
"""
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
@@ -162,10 +162,10 @@ def send_forward_recv_backward(output_tensor,
|
||||
|
||||
:param output_tensor: Tensor to be sent
|
||||
:param output_grad_shape: The shape of the tensor to be recieved
|
||||
:type output_tensor: Tensor
|
||||
:type output_grad_shape: torch.Size
|
||||
:type output_tensor: :class:`torch.Tensor`
|
||||
:type output_grad_shape: :class:`torch.Size`
|
||||
:return: The grad of output tensor in forward step
|
||||
:rtype: Tensor
|
||||
:rtype: :class:`torch.Tensor`
|
||||
"""
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
output_tensor_grad = None
|
||||
@@ -187,10 +187,10 @@ def send_backward_recv_forward(input_tensor_grad,
|
||||
|
||||
:param input_tensor_grad: Tensor to be sent
|
||||
:param input_tensor_shape: The shape of the tensor to be recieved
|
||||
:type input_tensor_grad: Tensor
|
||||
:type input_tensor_shape: torch.Size
|
||||
:type input_tensor_grad: :class:`torch.Tensor`
|
||||
:type input_tensor_shape: :class:`torch.Size`
|
||||
:return: The input tensor in forward step
|
||||
:rtype: Tensor
|
||||
:rtype: :class:`torch.Tensor`
|
||||
"""
|
||||
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
input_tensor = None
|
||||
@@ -213,10 +213,10 @@ def send_forward_recv_forward(output_tensor,
|
||||
|
||||
:param output_tensor: Tensor to be sent
|
||||
:param input_tensor_shape: The shape of the tensor to be recieved
|
||||
:type output_tensor: Tensor
|
||||
:type input_tensor_shape: torch.Size
|
||||
:type output_tensor: :class:`torch.Tensor`
|
||||
:type input_tensor_shape: :class:`torch.Size`
|
||||
:return: The input tensor in forward step
|
||||
:rtype: Tensor
|
||||
:rtype: :class:`torch.Tensor`
|
||||
"""
|
||||
input_tensor, _ = _communicate(tensor_send_next=output_tensor,
|
||||
recv_prev=recv_prev,
|
||||
@@ -237,10 +237,10 @@ def send_backward_recv_backward(input_tensor_grad,
|
||||
|
||||
:param input_tensor_grad: Tensor to be sent
|
||||
:param output_grad_shape: The shape of the tensor to be recieved
|
||||
:type input_tensor_grad: Tensor
|
||||
:type output_grad_shape: torch.Size
|
||||
:type input_tensor_grad: :class:`torch.Tensor`
|
||||
:type output_grad_shape: :class:`torch.Size`
|
||||
:return: The grad of output tensor in forward step
|
||||
:rtype: Tensor
|
||||
:rtype: :class:`torch.Tensor`
|
||||
"""
|
||||
_, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad,
|
||||
recv_next=recv_next,
|
||||
@@ -266,10 +266,10 @@ def send_forward_backward_recv_forward_backward(output_tensor,
|
||||
:param input_tensor_grad: Tensor sent to the previous
|
||||
:param input_tensor_shape: The shape of the tensor recieved from the previous
|
||||
:param output_grad_shape: The shape of the tensor recieved from the next
|
||||
:type output_tensor: Tensor
|
||||
:type input_tensor_grad: Tensor
|
||||
:type input_tensor_shape: torch.Size
|
||||
:type output_grad_shape: torch.Size
|
||||
:type output_tensor: :class:`torch.Tensor`
|
||||
:type input_tensor_grad: :class:`torch.Tensor`
|
||||
:type input_tensor_shape: :class:`torch.Size`
|
||||
:type output_grad_shape: :class:`torch.Size`
|
||||
:return: (the input tensor in forward step, the grad of output tensor in forward step)
|
||||
:rtype: (Tensor, Tensor)
|
||||
"""
|
||||
|
Reference in New Issue
Block a user