add scatter/gather optim for pipeline (#123)

This commit is contained in:
ver217
2022-01-07 13:22:22 +08:00
committed by GitHub
parent 404e6f88ed
commit 293fb40c42
5 changed files with 166 additions and 56 deletions

View File

@@ -1,12 +1,42 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Tuple, Union
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from functools import reduce
import operator
from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
TensorShape = Union[torch.Size, List[int], Tuple[int]]
def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
"""get the exact tensor shape when communicating and return whether the tensor is a chunk
:param tensor_shape: shape of tensor
:type tensor_shape: TensorShape
:param chunk_tensor: whether to chunk tensor, defaults to False
:type chunk_tensor: bool, optional
:return: exact tensor shape, whether to chunk tensor
:rtype: Tuple[Union[torch.Size, List[int], Tuple[int]], bool]
"""
if chunk_tensor:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
if tensor_chunk_shape % tensor_parallel_world_size == 0:
tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
else:
tensor_chunk_shape = tensor_shape
chunk_tensor = False
else:
tensor_chunk_shape = tensor_shape
return tensor_chunk_shape, chunk_tensor
def _communicate(tensor_send_next=None,
@@ -17,7 +47,8 @@ def _communicate(tensor_send_next=None,
recv_next_shape=None,
prev_rank=None,
next_rank=None,
dtype=None):
dtype=None,
scatter_gather_tensors=False):
"""
Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as helper method in other
@@ -42,13 +73,15 @@ def _communicate(tensor_send_next=None,
if recv_prev:
assert recv_prev_shape is not None
tensor_recv_prev = torch.empty(recv_prev_shape,
recv_prev_chunk_shape, recv_prev_split = _get_tensor_shape(recv_prev_shape, scatter_gather_tensors)
tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
if recv_next:
assert recv_next_shape is not None
tensor_recv_next = torch.empty(recv_next_shape,
recv_next_chunk_shape, recv_next_split = _get_tensor_shape(recv_next_shape, scatter_gather_tensors)
tensor_recv_next = torch.empty(recv_next_chunk_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
@@ -63,6 +96,16 @@ def _communicate(tensor_send_next=None,
next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE)
if tensor_send_prev is not None:
send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
if send_prev_split:
tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
if tensor_send_next is not None:
send_next_split = _get_tensor_shape(tensor_send_next.shape, scatter_gather_tensors)[1]
if send_next_split:
tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
ops = []
if tensor_send_prev is not None:
send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
@@ -82,10 +125,15 @@ def _communicate(tensor_send_next=None,
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
if recv_prev and recv_prev_split:
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
if recv_next and recv_next_split:
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
"""Receives the input tensor from the previous member in pipeline.
:param input_tensor_shape: The shape of the tensor to be recieved
@@ -101,11 +149,12 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
"""Receives the grad tensor from the next member in pipeline.
:param output_grad_shape: The shape of the tensor to be recieved
@@ -121,11 +170,12 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
_, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return output_tensor_grad
def send_forward(output_tensor, next_rank=None):
def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
"""Sends the input tensor to the next member in pipeline.
:param output_tensor: Tensor to be sent
@@ -135,10 +185,11 @@ def send_forward(output_tensor, next_rank=None):
"""
if not gpc.is_pipeline_last_stage():
_communicate(tensor_send_next=output_tensor,
next_rank=next_rank)
next_rank=next_rank,
scatter_gather_tensors=scatter_gather_tensors)
def send_backward(input_tensor_grad, prev_rank=None):
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
"""Sends the grad tensor to the previous member in pipeline.
:param input_tensor_grad: Tensor to be sent
@@ -148,14 +199,16 @@ def send_backward(input_tensor_grad, prev_rank=None):
"""
if not gpc.is_pipeline_first_stage():
_communicate(tensor_send_prev=input_tensor_grad,
prev_rank=prev_rank)
prev_rank=prev_rank,
scatter_gather_tensors=scatter_gather_tensors)
def send_forward_recv_backward(output_tensor,
output_grad_shape,
recv_next=True,
next_rank=None,
dtype=torch.float):
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the grad tensor from the
next member in pipeline.
@@ -174,7 +227,8 @@ def send_forward_recv_backward(output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return output_tensor_grad
@@ -182,7 +236,8 @@ def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
dtype=torch.float):
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while recieves the input tensor from the
previous member in pipeline.
@@ -201,7 +256,8 @@ def send_backward_recv_forward(input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor
@@ -210,7 +266,8 @@ def send_forward_recv_forward(output_tensor,
recv_prev=True,
prev_rank=None,
next_rank=None,
dtype=torch.float):
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the input tensor from the
previous member in pipeline.
@@ -227,7 +284,8 @@ def send_forward_recv_forward(output_tensor,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor
@@ -236,7 +294,8 @@ def send_backward_recv_backward(input_tensor_grad,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float):
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while recieves the grad tensor from the
next member in pipeline.
@@ -253,7 +312,8 @@ def send_backward_recv_backward(input_tensor_grad,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return output_tensor_grad
@@ -265,7 +325,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float):
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the next and
the grad tensor to the previous, while recieves the grad tensor from the
next and the input tensor from the previous.
@@ -290,5 +351,6 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor, output_tensor_grad