Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests
* fix: remove send_forward_recv_forward as p2p op list need to use the same group
* fix: make send and receive atomic
* feat: update P2PComm fn
* feat: add metadata cache in 1f1b
* feat: add metadata cache in interleaved pp
* feat: modify is_xx_stage fn
* revert: add _broadcast_object_list
* feat: add interleaved pp in llama policy
* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
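For orientation before the diff: after this change, every exchange in the patched p2p module's `_communicate` is a two-phase protocol. Shape/dtype metadata travels first as a pickled byte tensor, then the raw tensors follow, so the receiver can pre-allocate its buffers. A minimal sketch of that idea, assuming an already-initialized `torch.distributed` process group whose backend accepts CPU tensors (e.g. gloo); the helper names `send_obj`/`recv_obj` are illustrative only, not part of this diff:

    import pickle

    import torch
    import torch.distributed as dist

    def send_obj(t: torch.Tensor, dst: int) -> None:
        # Phase 1: pickle only the metadata (shape/dtype), never the payload.
        meta = torch.frombuffer(bytearray(pickle.dumps((t.shape, t.dtype))), dtype=torch.uint8)
        dist.send(torch.tensor([meta.numel()], dtype=torch.long), dst)
        dist.send(meta, dst)
        # Phase 2: the raw tensor; the receiver already knows how to size its buffer.
        dist.send(t.contiguous(), dst)

    def recv_obj(src: int) -> torch.Tensor:
        n = torch.empty(1, dtype=torch.long)
        dist.recv(n, src)
        meta = torch.empty(int(n.item()), dtype=torch.uint8)
        dist.recv(meta, src)
        shape, dtype = pickle.loads(bytes(meta.tolist()))
        buf = torch.empty(shape, dtype=dtype)  # buffer pre-allocated from metadata
        dist.recv(buf, src)
        return buf

The real implementation below batches the tensor sends/recvs through `dist.batch_isend_irecv`, and the new `send_metadata`/`metadata_recv` parameters let callers skip phase 1 entirely once the shapes are known.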
@@ -4,13 +4,13 @@
 import io
 import pickle
 import re
-from typing import Any, List, Optional, Union
+from collections import namedtuple
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, List, Optional, Union

 import torch
 import torch.distributed as dist
-from dataclasses import dataclass
-from enum import Enum
 from packaging.version import Version
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
@@ -20,7 +20,7 @@ from .stage_manager import PipelineStageManager
 _unpickler = pickle.Unpickler


-def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
+def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
     """transform tensor to object with unpickle.
     Info of the device in bytes stream will be modified into current device before unpickling
@@ -48,21 +48,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     return unpickle


-def check_for_nccl_backend(group):
-    pg = group or c10d._get_default_group()
-    # Gate PG wrapper check on Gloo availability.
-    if c10d._GLOO_AVAILABLE:
-        # It is not expected for PG to be wrapped many times, but support it just
-        # in case
-        while isinstance(pg, c10d._ProcessGroupWrapper):
-            pg = pg.wrapped_pg
-
-    return (
-        c10d.is_nccl_available() and
-        pg.name() == c10d.Backend.NCCL
-    )
-
-
 # NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
 def _broadcast_object_list(
     object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
 ):
@@ -70,13 +56,11 @@ def _broadcast_object_list(
     The only difference is that object will be move to correct device after unpickled.
     If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
     be updated with data sent from rank src.

     Args:
         object_list (List[Any]): list of object to broadcast
         src (int): source rank to broadcast
-        dst (int): dst rank to broadcast
         device (:class:`torch.device`): device to do broadcast. current device in default
-
     """

     if c10d._rank_not_in_group(group):
@@ -131,7 +115,7 @@ def _broadcast_object_list(
     if my_rank != src:
         for i, obj_size in enumerate(object_sizes_tensor):
-            obj_view = object_tensor[offset: offset + obj_size]
+            obj_view = object_tensor[offset : offset + obj_size]
             obj_view = obj_view.type(torch.uint8)
             if obj_view.device != torch.device("cpu"):
                 obj_view = obj_view.cpu()
@@ -149,6 +133,18 @@ def _broadcast_object_list(
             object_list[i] = unpickle_object


+def check_for_nccl_backend(group):
+    pg = group or c10d._get_default_group()
+    # Gate PG wrapper check on Gloo availability.
+    if c10d._GLOO_AVAILABLE:
+        # It is not expected for PG to be wrapped many times, but support it just
+        # in case
+        while isinstance(pg, c10d._ProcessGroupWrapper):
+            pg = pg.wrapped_pg
+
+    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
+
+
 def check_device(group):
     is_nccl_backend = check_for_nccl_backend(group)
     current_device = None
@@ -159,14 +155,14 @@ def check_device(group):
     return current_device, is_nccl_backend


-TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])
+TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"])


 class P2PDataType(Enum):
-    serialization = 0
-    tensor = 1
-    list = 2
-    dict = 3
+    Serialization = 0
+    Tensor = 1
+    List = 2
+    Dict = 3


 @dataclass
@@ -175,45 +171,71 @@ class P2PMetadata:
     content: Union[List[TensorMetadata], TensorMetadata, Any]


-def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
+def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup):
     if isinstance(obj, torch.Tensor):
         obj = obj.contiguous()
         op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
         ops_queue.append(op_to_add)
     else:
         for tensor_to_comm in obj:
-            tensor_to_comm = tensor_to_comm.contiguous()
-            op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group)
-            ops_queue.append(op_to_add)
+            assert isinstance(tensor_to_comm, torch.Tensor)
+            filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)


-def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
-    if p2p_metadata.data_type == P2PDataType.tensor:
+def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any):
+    if p2p_metadata.data_type == P2PDataType.Tensor:
         metadata = p2p_metadata.content
-        tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
+        tensor_recv = torch.empty(
+            metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
+        )
         return tensor_recv
-    elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict):
+    elif p2p_metadata.data_type in (P2PDataType.List, P2PDataType.Dict):
         buffer_recv = []
         for metadata in p2p_metadata.content:
-            tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
+            tensor_recv = torch.empty(
+                metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
+            )
             buffer_recv.append(tensor_recv)
         return buffer_recv
     else:
         raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")


-def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
+def create_fast_send_metadata(object: Any) -> P2PMetadata:
+    assert _check_if_fast_send_available(object)
+    if isinstance(object, torch.Tensor):
+        data_type = P2PDataType.Tensor
+        content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
+    elif isinstance(object, list):
+        data_type = P2PDataType.List
+        content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object]
+    elif isinstance(object, dict):
+        data_type = P2PDataType.Dict
+        content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()]
+    else:
+        raise RuntimeError("Cannot handle object of type {}".format(type(object)))
+    return P2PMetadata(data_type, content)
+
+
+def _batch_send_recv_tensor(
+    send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]],
+    recv_tensor_metadata: Optional[P2PMetadata],
+    send_dst: Optional[int],
+    recv_src: Optional[int],
+    send_group: Optional[ProcessGroup],
+    recv_group: Optional[ProcessGroup],
+    current_device: Any,
+) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
     buffer_recv = None
-    if recv_tensor_metadata is not None:
+    if recv_tensor_metadata is not None and recv_tensor_metadata.data_type != P2PDataType.Serialization:
         buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)

     ops = []

-    if send_dst is not None:
+    if send_dst is not None and send_tensor_list is not None:
+        assert send_group is not None
         filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
-
-    if recv_src is not None:
-        assert buffer_recv is not None
+    if recv_src is not None and buffer_recv is not None:
+        assert recv_group is not None
         filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)

     if len(ops) > 0:
@@ -221,24 +243,26 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
         for req in reqs:
             req.wait()

-        torch.cuda.synchronize()
-
-    torch.cuda.synchronize()
+    # Remove synchronization according to Pytorch's documentation
+    # However, the Megatron-LM does synchronization here
+    # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
+    # In case there is potential error, uncomment the following `torch.cuda.synchronize()`
+    # torch.cuda.synchronize()

     return buffer_recv


 def _send_recv_serialization_object(
-    object: Any,
-    send_dst: Optional[int], recv_src: Optional[int],
-    send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup],
-    current_device,
-    is_nccl_backend):
+    object: Any,
+    send_dst: Optional[int],
+    recv_src: Optional[int],
+    send_group: Optional[ProcessGroup],
+    recv_group: Optional[ProcessGroup],
+    current_device: Any,
+    is_nccl_backend: bool,
+) -> Optional[P2PMetadata]:
     ops = []

     send_object_tensor = None
     if object is not None and send_dst is not None:
         if Version(torch.__version__) >= Version("1.13.0"):
@@ -264,10 +288,8 @@ def _send_recv_serialization_object(
         for req in reqs:
             req.wait()

-        torch.cuda.synchronize()
-
-    torch.cuda.synchronize()
+    # See the comment in `_batch_send_recv_tensor`
+    # torch.cuda.synchronize()

     ops = []

@@ -286,52 +308,77 @@ def _send_recv_serialization_object(
         for req in reqs:
             req.wait()

-        torch.cuda.synchronize()
-
-    torch.cuda.synchronize()
+    # See the comment in `_batch_send_recv_tensor`
+    # torch.cuda.synchronize()

     if recv_object_tensor is not None and recv_object_size_tensor is not None:
         recv_object_tensor = recv_object_tensor.type(torch.uint8)
         if recv_object_tensor.device != torch.device("cpu"):
             recv_object_tensor = recv_object_tensor.cpu()

-        unpickle_object = _cuda_safe_tensor_to_object(
-            recv_object_tensor, recv_object_size_tensor.item())
+        unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())

-        if (
-            isinstance(unpickle_object, torch.Tensor)
-            and unpickle_object.device.index != torch.cuda.current_device()
-        ):
+        if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
             unpickle_object = unpickle_object.cuda()

         return unpickle_object


-def _check_if_fast_send_available(object):
-    if type(object) is torch.Tensor:
+def _check_if_fast_send_available(object: Any) -> bool:
+    if isinstance(object, torch.Tensor):
         return True
-    elif type(object) is list:
-        is_list_of_tensor = all([type(v) is torch.Tensor for v in object])
+    elif isinstance(object, list):
+        is_list_of_tensor = all([isinstance(v, torch.Tensor) for v in object])
         return is_list_of_tensor
-    elif type(object) is dict:
-        is_dict_of_tensor = all([type(k) is str and type(
-            v) is torch.Tensor for k, v in object.items()])
-
+    elif isinstance(object, dict):
+        is_dict_of_tensor = all([isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in object.items()])
         return is_dict_of_tensor
     return False


 def _communicate(
-    object,
+    object: Any,
     send_dst: Optional[int],
     recv_src: Optional[int],
     send_group: Optional[ProcessGroup] = None,
     recv_group: Optional[ProcessGroup] = None,
+    send_metadata: bool = True,
+    metadata_recv: Optional[P2PMetadata] = None,
 ) -> Any:
-    if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
-        c10d._warn_not_in_group("_communicate")
-        return
+    """
+    Send and receive object from send_dst and recv_src respectively

+    Args:
+        object (Any): object needed to be sent
+        send_dst (int): rank of the destination
+        recv_src (int): rank of the source
+        send_group (ProcessGroup, optional): process group of sender
+        recv_group (ProcessGroup, optional): process group of receiver
+        send_metadata (bool, optional): whether to send metadata
+        metadata_recv (P2PMetadata, optional): metadata of the object to be received
+    """
+    assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
+    assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
+    assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
+    send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object))
+    assert (
+        metadata_recv is None or metadata_recv.data_type != P2PDataType.Serialization
+    ), "metadata_recv type must not be Serialization"
+
+    # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
+    # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
+    if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
+        _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
+        return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+
+    # NOTE: only the following 3 cases are valid:
+    # 1. send() [needs extra metadata] and no recv()
+    # 2. recv() [needs extra metadata] and no send()
+    # 3. neither send() nor recv() need extra metadata
+    assert not (send_dst is not None and send_metadata) or recv_src is None
+    assert not (recv_src is not None and metadata_recv is None) or send_dst is None
+    assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
+    assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)

     current_send_device, is_send_nccl_backend = check_device(send_group)
     current_recv_device, is_recv_nccl_backend = check_device(recv_group)
@@ -341,67 +388,56 @@ def _communicate(
     assert current_send_device == current_recv_device
     current_device = current_send_device

-    assert (send_dst is not None) or (recv_src is not None)
-
-    can_fast_send = False
-    send_metadata = None
-    if send_dst is not None:
-        can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
-        if not can_fast_send:
-            send_metadata = P2PMetadata(P2PDataType.serialization, object)
-        else:
-            if type(object) is torch.Tensor:
-                data_type = P2PDataType.tensor
-                content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
-            elif type(object) is list:
-                data_type = P2PDataType.list
-                content = []
-                for v in object:
-                    content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad))
-            elif type(object) is dict:
-                data_type = P2PDataType.dict
-                content = []
-                for k, v in object.items():
-                    content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad))
-            else:
-                raise ValueError('Cannot send object of type {}'.format(type(object)))
-            send_metadata = P2PMetadata(data_type, content)
-
-    recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
-    if recv_metadata is not None:
-        assert type(recv_metadata) is P2PMetadata
-        if recv_metadata.data_type == P2PDataType.serialization:
-            return recv_metadata.content
-        if not can_fast_send and send_dst is not None:
-            return
+    if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
+        metadata_send = None
+        if send_dst is not None and send_metadata:
+            can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
+            if not can_fast_send:
+                metadata_send = P2PMetadata(P2PDataType.Serialization, object)
+            else:
+                metadata_send = create_fast_send_metadata(object)
+
+        # Send and receive metadata
+        _metadata_recv = _send_recv_serialization_object(
+            object=metadata_send,
+            send_dst=send_dst if send_metadata else None,
+            recv_src=recv_src if metadata_recv is None else None,
+            send_group=send_group if send_metadata else None,
+            recv_group=recv_group if metadata_recv is None else None,
+            current_device=current_device,
+            is_nccl_backend=is_nccl_backend,
+        )
+        assert metadata_recv is None or _metadata_recv is None
+        metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv

     send_tensor_list = None
-    if type(object) is torch.Tensor:
+    if isinstance(object, torch.Tensor):
         send_tensor_list = object
-    elif type(object) is list:
+    elif isinstance(object, list):
         send_tensor_list = object
-    elif type(object) is dict:
+    elif isinstance(object, dict):
         send_tensor_list = list(object.values())

-    recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)
+    # Send and receive data
+    recv_buffer = _batch_send_recv_tensor(
+        send_tensor_list, metadata_recv, send_dst, recv_src, send_group, recv_group, current_device
+    )

-    if recv_metadata is not None:
-        assert recv_buffer is not None
-        if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
-            return recv_buffer
-        elif recv_metadata.data_type == P2PDataType.dict:
-            return {
-                k: v
-                for k, v in zip(
-                    [m.key for m in recv_metadata.content],
-                    recv_buffer,
-                )
-            }
-        else:
-            raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))
+    if metadata_recv is not None:
+        assert isinstance(metadata_recv, P2PMetadata)
+        if metadata_recv.data_type == P2PDataType.Serialization:
+            return metadata_recv.content
+        assert recv_buffer is not None
+        if metadata_recv.data_type in [P2PDataType.Tensor, P2PDataType.List]:
+            return recv_buffer
+        elif metadata_recv.data_type == P2PDataType.Dict:
+            return {k: v for k, v in zip([m.key for m in metadata_recv.content], recv_buffer)}
+        else:
+            raise ValueError("Unknown data type {}".format(metadata_recv.data_type))


-def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
+def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None:
     """send anything to dst rank

     Args:
@@ -411,10 +447,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
     Returns:
         None
     """
-    _communicate(object, send_dst=dst, recv_src=None, send_group=group)
+    _communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata)


-def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
+def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any:
     """recv anything from src

     Args:
@@ -423,7 +459,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
     Returns:
         Any: Object received from src.
     """
-    return _communicate(None, send_dst=None, recv_src=src, recv_group=group)
+    return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv)


 def _p2p_comm(
@@ -436,7 +472,7 @@ def _p2p_comm(
     """
     Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.

-    Agrs:
+    Args:
         tensor_send_next (torch.Tensor): tensor to be sent to next stage
         recv_prev (bool): whether to receive tensor from previous stage
         peer (int): rank of the peer
@@ -467,7 +503,6 @@
             group=group,
         )
         ops.append(recv_prev_op)
-
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
@@ -490,7 +525,6 @@
             group=group,
         )
         ops.append(send_next_op)
-
     if tensor_recv_prev is not None:
         recv_prev_op = dist.P2POp(
             dist.irecv,
@@ -510,7 +544,7 @@ class PipelineP2PCommunication:
     def __init__(self, stage_manager: PipelineStageManager) -> None:
         self.stage_manager = stage_manager

-    def recv_forward(self, prev_rank: int = None) -> Any:
+    def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

         Args:
@@ -522,11 +556,13 @@ class PipelineP2PCommunication:
         if prev_rank is None:
             prev_rank = self.stage_manager.get_prev_rank()
         cur_rank = self.stage_manager.get_rank()
-        input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
+        input_tensor = _recv_object(
+            prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv
+        )

         return input_tensor

-    def recv_backward(self, next_rank: int = None) -> Any:
+    def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
         """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

         Args:
@@ -539,12 +575,12 @@ class PipelineP2PCommunication:
             next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
         output_tensor_grad = _recv_object(
-            next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank)
+            next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv
         )

         return output_tensor_grad

-    def send_forward(self, output_object: Any, next_rank: int = None) -> None:
+    def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
         """Sends the input tensor to the next stage in pipeline.

         Args:
@@ -554,9 +590,15 @@ class PipelineP2PCommunication:
         if next_rank is None:
             next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
-        _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
+        _send_object(
+            output_object,
+            cur_rank,
+            next_rank,
+            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
+            send_metadata,
+        )

-    def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
+    def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.

         Args:
@@ -566,9 +608,21 @@ class PipelineP2PCommunication:
         if prev_rank is None:
             prev_rank = self.stage_manager.get_prev_rank()
         cur_rank = self.stage_manager.get_rank()
-        _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
+        _send_object(
+            input_object,
+            cur_rank,
+            prev_rank,
+            self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
+            send_metadata,
+        )

-    def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
+    def send_forward_recv_backward(
+        self,
+        input_object: Any,
+        next_rank: Optional[int] = None,
+        send_metadata: bool = True,
+        metadata_recv: Optional[P2PMetadata] = None,
+    ) -> Any:
         """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline

         Args:
@@ -581,11 +635,22 @@ class PipelineP2PCommunication:
         cur_rank = self.stage_manager.get_rank()
         group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
         return _communicate(
-            input_object, next_rank, next_rank,
-            send_group=group, recv_group=group,
+            input_object,
+            next_rank,
+            next_rank,
+            send_group=group,
+            recv_group=group,
+            send_metadata=send_metadata,
+            metadata_recv=metadata_recv,
         )

-    def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
+    def send_backward_recv_forward(
+        self,
+        input_object: Any,
+        prev_rank: Optional[int] = None,
+        send_metadata: bool = True,
+        metadata_recv: Optional[P2PMetadata] = None,
+    ) -> Any:
         """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline

         Args:
@@ -597,37 +662,22 @@ class PipelineP2PCommunication:

         cur_rank = self.stage_manager.get_rank()
         group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
         return _communicate(
-            input_object, prev_rank, prev_rank,
-            send_group=group, recv_group=group,
-        )
-
-    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
-        """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
-
-        Args:
-            input_object (Any): Object to be sent.
-            prev_rank (int, optional): The rank of the sender of the tensor
-            next_rank (int, optional): The rank of the recipient of the tensor
-        """
-        if prev_rank is None:
-            prev_rank = self.stage_manager.get_prev_rank()
-        if next_rank is None:
-            next_rank = self.stage_manager.get_next_rank()
-
-        cur_rank = self.stage_manager.get_rank()
-        recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
-        send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
-        return _communicate(
-            input_object,
-            send_dst=next_rank,
-            recv_src=prev_rank,
-            send_group=send_group,
-            recv_group=recv_group,
+            input_object,
+            prev_rank,
+            prev_rank,
+            send_group=group,
+            recv_group=group,
+            send_metadata=send_metadata,
+            metadata_recv=metadata_recv,
         )

     def p2p_communicate(
-        self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
+        self,
+        output_object: Any,
+        recv_pre: bool,
+        next_rank: Optional[int] = None,
+        comm_dtype: torch.dtype = torch.float16,
     ) -> None:
         """
         Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
@@ -636,10 +686,14 @@ class PipelineP2PCommunication:
             output_object (Any): Object to be sent.
             next_rank (int, optional): The rank of the recipient of the tensor.
         """
-        if peer is None:
-            peer = self.stage_manager.get_next_rank()
+        if next_rank is None:
+            next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
         recv_tensor = _p2p_comm(
-            output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
+            output_object,
+            recv_pre,
+            next_rank,
+            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
+            comm_dtype,
        )
         return recv_tensor
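The schedule-level payoff of the new `send_metadata`/`metadata_recv` parameters is that a pipeline schedule pays for the pickle round-trip only once per run. A hedged sketch of the caching pattern, assuming `comm` is a `PipelineP2PCommunication` on a middle stage, `forward_step` and `num_microbatches` stand in for the schedule's own machinery, and every microbatch has identical shapes and dtypes (the precondition for the cache; the actual 1F1B and interleaved schedule code in this commit differs in detail):

    send_metadata = True          # ship shapes/dtypes with the first microbatch only
    cached_metadata_recv = None   # filled after the first receive

    for _ in range(num_microbatches):
        input_tensor = comm.recv_forward(metadata_recv=cached_metadata_recv)
        if cached_metadata_recv is None:
            # Rebuild P2PMetadata locally from the first received tensor(s), so all
            # later recvs can pre-allocate buffers without another metadata exchange.
            cached_metadata_recv = create_fast_send_metadata(input_tensor)

        output_tensor = forward_step(input_tensor)

        comm.send_forward(output_tensor, send_metadata=send_metadata)
        send_metadata = False     # subsequent sends skip the serialization phase

This is also why the "neither send() nor recv() need extra metadata" case in `_communicate` matters: once both directions run metadata-free, a combined send+recv can stay a single atomic `batch_isend_irecv`.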
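The commit message also mentions setting NCCL_BUFFSIZE in HybridParallelPlugin. NCCL reads that variable (buffer size in bytes, 4 MB by default) when communicators are created, so it must be exported before the first collective triggers NCCL initialization; the 128 MB value below is purely illustrative, since the plugin's actual choice is not shown on this page:

    import os

    # Must run before dist.init_process_group() initializes NCCL.
    if "NCCL_BUFFSIZE" not in os.environ:
        os.environ["NCCL_BUFFSIZE"] = str(128 * 1024 * 1024)  # bytes; illustrative value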