Migrated project

This commit is contained in:
zbian
2021-10-28 18:21:23 +02:00
parent 2ebaefc542
commit 404ecbdcc6
409 changed files with 35853 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
from .collective import all_gather, reduce_scatter, scatter
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
send_backward, send_backward_recv_backward, send_forward_recv_backward,
send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta
__all__ = [
'all_gather', 'reduce_scatter', 'scatter',
'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
'send_forward_recv_backward', 'recv_backward', 'recv_forward',
'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
]
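
For orientation, this __init__ re-exports the package's public surface, so callers can pull every communication primitive from one place instead of from the individual modules. A minimal usage sketch, assuming the package path is colossalai.communication:

# Hypothetical usage, assuming this __init__.py sits in colossalai.communication
from colossalai.communication import all_gather, send_forward, recv_tensor_meta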

View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def all_gather(tensor: Tensor, dim: int,
parallel_mode: ParallelMode) -> Tensor:
"""Gathers all tensors from the parallel group and concatenates them in a
specific dimension.
:param tensor: Tensor to be gathered
:param dim: The dimension to concatenate along
:param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor
:type dim: int
:type parallel_mode: ParallelMode
:return: The tensor generated by all-gather
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
shape = list(temp.shape)
shape[dim] *= depth
out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
out = list(torch.chunk(out, depth, dim=dim))
out = [val.contiguous() for val in out]
dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
out = torch.cat(out, dim=dim)
return out
def reduce_scatter(tensor: Tensor, dim: int,
parallel_mode: ParallelMode) -> Tensor:
"""Reduces all tensors, then scatters the result along a specific dimension to all
members in the parallel group.
:param tensor: Tensor to be reduced and scattered
:param dim: The dimension to scatter along
:param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor
:type dim: int
:type parallel_mode: ParallelMode
:return: The tensor generated by reduce-scatter
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
temp = list(torch.chunk(tensor, depth, dim=dim))
temp = [val.contiguous() for val in temp]
out = torch.empty(temp[0].shape,
dtype=temp[0].dtype,
device=get_current_device())
dist.reduce_scatter(output=out,
input_list=temp,
group=gpc.get_group(parallel_mode))
return out
def scatter(tensor: Tensor, src: int, dim: int,
parallel_mode: ParallelMode) -> Tensor:
"""Scatters in a specific dimension from source rank to all ranks in
the parallel group.
:param tensor: Tensor to be scattered
:param src: The source rank of the tensor in the parallel group
:param dim: The dimension to scatter along
:param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor
:type src: int
:type dim: int
:type parallel_mode: ParallelMode
:return: The tensor generated by scatter
:rtype: Tensor
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
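# Note: this scatter is a broadcast of the full tensor followed by a local chunk
# selection below, so every rank briefly holds the complete tensor before
# keeping only its own slice.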
dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
rank = gpc.get_local_rank(parallel_mode)
out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
return out
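
All three collectives follow the same recipe: resolve the process group for the given ParallelMode through gpc, call the matching torch.distributed primitive, and chunk or concatenate along dim. The standalone sketch below mirrors the all_gather logic against a raw torch.distributed process group instead of a ParallelMode; the group handle and process-group initialization are assumptions, not taken from this file.

import torch
import torch.distributed as dist

def all_gather_sketch(tensor: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    # Same pattern as all_gather above, but parameterized by a raw process group;
    # assumes torch.distributed has already been initialized elsewhere.
    world_size = dist.get_world_size(group)
    shape = list(tensor.shape)
    shape[dim] *= world_size
    # Allocate one output buffer, hand all_gather a list of contiguous per-rank
    # chunks of it, then stitch the chunks back together along dim.
    out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
    chunks = [c.contiguous() for c in torch.chunk(out, world_size, dim=dim)]
    dist.all_gather(chunks, tensor.contiguous(), group=group)
    return torch.cat(chunks, dim=dim)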

View File

@@ -0,0 +1,333 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def _communicate(tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
recv_prev_shape=None,
recv_next_shape=None,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None,
dtype=None):
"""
Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as helper method in other
communication methods that are used in pipeline schedule.
Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
if recv_prev:
assert recv_prev_shape is not None
tensor_recv_prev = torch.empty(recv_prev_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
if recv_next:
assert recv_next_shape is not None
tensor_recv_next = torch.empty(recv_next_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
if tensor_send_prev is not None or recv_prev:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
if up_group is None:
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
if tensor_send_next is not None or recv_next:
if next_rank is None:
next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE)
if down_group is None:
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
# rank = dist.get_rank()
rank = gpc.get_global_rank()
ops = []
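# Point-to-point transfers below are emulated with asynchronous broadcasts
# inside up_group/down_group, the small process groups shared with the
# neighbouring pipeline stages.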
if tensor_send_prev is not None:
send_prev_op = dist.broadcast(tensor_send_prev,
src=rank,
group=up_group,
async_op=True)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = dist.broadcast(tensor_recv_prev,
src=prev_rank,
group=up_group,
async_op=True)
ops.append(recv_prev_op)
if tensor_recv_next is not None:
recv_next_op = dist.broadcast(tensor_recv_next,
src=next_rank,
group=down_group,
async_op=True)
ops.append(recv_next_op)
if tensor_send_next is not None:
send_next_op = dist.broadcast(tensor_send_next,
src=rank,
group=down_group,
async_op=True)
ops.append(send_next_op)
for req in ops:
req.wait()
# To protect against race conditions once the asynchronous broadcast ops complete.
torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
"""Receives the input tensor from the previous member in pipeline.
:param input_tensor_shape: The shape of the tensor to be received
:param prev_rank: The rank of the source of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_shape: torch.Size
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
:return: The input tensor in forward step
:rtype: Tensor
"""
if gpc.is_first_rank(ParallelMode.PIPELINE):
input_tensor = None
else:
input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
up_group=up_group)
return input_tensor
def recv_backward(output_grad_shape, next_rank=None, down_group=None):
"""Receives the grad tensor from the next member in pipeline.
:param output_grad_shape: The shape of the tensor to be received
:param next_rank: The rank of the source of the tensor
:param down_group: Communication group including the next member in pipeline parallel group
:type output_grad_shape: torch.Size
:type next_rank: int, optional
:type down_group: ProcessGroup, optional
:return: The grad of output tensor in forward step
:rtype: Tensor
"""
if gpc.is_last_rank(ParallelMode.PIPELINE):
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
down_group=down_group)
return output_tensor_grad
def send_forward(output_tensor,
next_rank=None,
down_group=None):
"""Sends the input tensor to the next member in pipeline.
:param output_tensor: Tensor to be sent
:param next_rank: The rank of the recipient of the tensor
:param down_group: Communication group including the next member in pipeline parallel group
:type output_tensor: Tensor
:type next_rank: int, optional
:type down_group: ProcessGroup, optional
"""
if not gpc.is_last_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_next=output_tensor,
next_rank=next_rank,
down_group=down_group)
def send_backward(input_tensor_grad,
prev_rank=None,
up_group=None):
"""Sends the grad tensor to the previous member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param prev_rank: The rank of the recipient of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_grad: Tensor
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
"""
if not gpc.is_first_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_prev=input_tensor_grad,
prev_rank=prev_rank,
up_group=up_group)
def send_forward_recv_backward(output_tensor,
output_grad_shape,
recv_next=True,
next_rank=None,
down_group=None):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while receiving the grad tensor from the
next member in pipeline.
:param output_tensor: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be received
:type output_tensor: Tensor
:type output_grad_shape: torch.Size
:return: The grad of output tensor in forward step
:rtype: Tensor
"""
if gpc.is_last_rank(ParallelMode.PIPELINE):
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
down_group=down_group)
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
up_group=None):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while receiving the input tensor from the
previous member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be received
:type input_tensor_grad: Tensor
:type input_tensor_shape: torch.Size
:return: The input tensor in forward step
:rtype: Tensor
"""
if gpc.is_first_rank(ParallelMode.PIPELINE):
input_tensor = None
else:
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
up_group=up_group)
return input_tensor
def send_forward_recv_forward(output_tensor,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while receiving the input tensor from the
previous member in pipeline.
:param output_tensor: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be received
:type output_tensor: Tensor
:type input_tensor_shape: torch.Size
:return: The input tensor in forward step
:rtype: Tensor
"""
input_tensor, _ = _communicate(tensor_send_next=output_tensor,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
return input_tensor
def send_backward_recv_backward(input_tensor_grad,
output_grad_shape,
recv_next=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while receiving the grad tensor from the
next member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be received
:type input_tensor_grad: Tensor
:type output_grad_shape: torch.Size
:return: The grad of output tensor in forward step
:rtype: Tensor
"""
_, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
return output_tensor_grad
def send_forward_backward_recv_forward_backward(output_tensor,
input_tensor_grad,
input_tensor_shape,
output_grad_shape,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
"""Batched communication operation. Sends the input tensor to the next and
the grad tensor to the previous, while receiving the grad tensor from the
next and the input tensor from the previous.
:param output_tensor: Tensor sent to the next
:param input_tensor_grad: Tensor sent to the previous
:param input_tensor_shape: The shape of the tensor received from the previous
:param output_grad_shape: The shape of the tensor received from the next
:type output_tensor: Tensor
:type input_tensor_grad: Tensor
:type input_tensor_shape: torch.Size
:type output_grad_shape: torch.Size
:return: (the input tensor in forward step, the grad of output tensor in forward step)
:rtype: (Tensor, Tensor)
"""
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
return input_tensor, output_tensor_grad
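
Taken together, these wrappers let a pipeline schedule pair up its sends and receives without touching ranks or process groups directly. A minimal sketch of one 1F1B-style step on a middle pipeline stage is shown below; stage_forward, stage_backward, input_tensor_shape and output_grad_shape are hypothetical placeholders for the schedule that actually drives these calls.

# Hypothetical single step on a middle pipeline stage; only the helper names
# come from this file, everything else is a placeholder.
input_tensor = recv_forward(input_tensor_shape)        # activation from the previous stage
output_tensor = stage_forward(input_tensor)            # run this stage's layers
output_tensor_grad = send_forward_recv_backward(output_tensor,
                                                output_grad_shape)
input_tensor_grad = stage_backward(input_tensor, output_tensor, output_tensor_grad)
send_backward(input_tensor_grad)                       # gradient back to the previous stage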

View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
"""Sends a tensor to the next member and receives a tensor from the previous member.
This function returns the tensor received from the previous member.
:param tensor_send_next: Tensor sent to next member
:param parallel_mode: Parallel group mode used in this communication
:type tensor_send_next: Tensor
:type parallel_mode: ParallelMode
:return: The tensor received from the previous member
:rtype: Tensor
"""
buffer_shape = tensor_send_next.size()
ops = []
current_rank = gpc.get_global_rank()
tensor_recv_prev = torch.empty(buffer_shape,
requires_grad=True,
device=get_current_device(),
dtype=tensor_send_next.dtype)
# send to next rank
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next,
gpc.get_next_global_rank(parallel_mode))
ops.append(send_next_op)
# receive from prev rank
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev,
gpc.get_prev_global_rank(parallel_mode))
ops.append(recv_prev_op)
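# Reverse the op order on even ranks so that neighbouring ranks post their
# send/recv in opposite orders (presumably to keep the ring exchange from
# stalling when every rank issues the same op first).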
if current_rank % 2 == 0:
ops = ops[::-1]
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
synchronize()
return tensor_recv_prev
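
Each call to ring_forward moves a buffer one hop around the group, so chaining world_size - 1 calls walks every rank's tensor past every other rank. A sketch of that accumulation pattern follows; local_tensor, partial and combine are hypothetical placeholders.

# Hypothetical ring accumulation built on ring_forward; combine() is a placeholder.
remote = local_tensor
for _ in range(gpc.get_world_size(parallel_mode) - 1):
    remote = ring_forward(remote, parallel_mode)   # pull the next neighbour's copy
    partial = combine(partial, local_tensor, remote)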

View File

@@ -0,0 +1,73 @@
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def send_tensor_meta(tensor, need_meta=True, down_group=None):
"""Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function
synchronizes with :func:`recv_tensor_meta`.
:param tensor: Tensor to be sent
:param need_meta: If False, meta information won't be sent
:param down_group: Communication group including the next member in pipeline parallel group
:type tensor: Tensor
:type need_meta: bool, optional
:type down_group: ProcessGroup, optional
:return: False
:rtype: bool
"""
if need_meta:
rank = gpc.get_global_rank()
if down_group is None:
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
dist.broadcast(send_ndims, src=rank, group=down_group)
dist.broadcast(send_shape, src=rank, group=down_group)
return False
def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
"""Receives tensor meta information before receiving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be received before communications. This function
synchronizes with :func:`send_tensor_meta`.
:param tensor_shape: The shape of the tensor to be received
:param prev_rank: The rank of the source of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type tensor_shape: torch.Size
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
:return: The shape of the tensor to be received
:rtype: torch.Size
"""
if tensor_shape is None:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
if up_group is None:
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
recv_ndims = torch.empty((), **tensor_kwargs)
dist.broadcast(recv_ndims, src=prev_rank, group=up_group)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.broadcast(recv_shape, src=prev_rank, group=up_group)
tensor_shape = torch.Size(recv_shape)
return tensor_shape
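
The two helpers form a small handshake: the sender broadcasts the number of dimensions and then the shape itself, and the receiver mirrors both broadcasts so it can allocate a correctly sized buffer before recv_forward. A minimal sketch of the pairing across two adjacent stages; output_tensor and the surrounding schedule are placeholders.

# Hypothetical pairing across adjacent pipeline stages.
# Sending stage, once before the first real transfer:
need_meta = send_tensor_meta(output_tensor, need_meta=True)   # returns False afterwards
send_forward(output_tensor)

# Receiving stage:
shape = recv_tensor_meta(None)     # None means the shape is unknown, so fetch it
input_tensor = recv_forward(shape)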