[pipeline] A more general _communicate in p2p (#5062)

* A more general _communicate

* feat: finish tree_flatten version p2p

* fix: update p2p api calls

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
Elsa Granger authored on 2024-01-08 15:37:27 +08:00; committed by GitHub
parent 7bc6969ce6
commit d565df3821
4 changed files with 104 additions and 136 deletions

colossalai/pipeline/p2p.py

@@ -5,20 +5,17 @@ import io
 import pickle
 import re
 from collections import namedtuple
-from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from packaging.version import Version
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
+from torch.utils._pytree import tree_flatten, tree_unflatten

 from .stage_manager import PipelineStageManager

-_unpickler = pickle.Unpickler
-

 def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
     """transform tensor to object with unpickle.
@@ -42,7 +39,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
     buf = bytes(buf_array)

     io_bytes = io.BytesIO(buf)
-    byte_pickler = _unpickler(io_bytes)
+    byte_pickler = pickle.Unpickler(io_bytes)
     unpickle = byte_pickler.load()

     return unpickle
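For reference, this helper just reverses c10d's object-to-tensor encoding. A minimal standalone sketch of that round trip (not part of the diff; values are illustrative):

    import io
    import pickle

    import torch

    obj = {"step": 3, "loss": 0.25}  # any picklable object
    buf = pickle.dumps(obj)
    byte_tensor = torch.frombuffer(bytearray(buf), dtype=torch.uint8)  # object -> uint8 tensor

    # Receiving side: uint8 tensor -> object, as _cuda_safe_tensor_to_object does
    restored = pickle.Unpickler(io.BytesIO(byte_tensor.numpy().tobytes())).load()
    assert restored == obj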
@@ -67,7 +64,7 @@ def _broadcast_object_list(
         c10d._warn_not_in_group("broadcast_object_list")
         return

-    is_nccl_backend = check_for_nccl_backend(group)
+    is_nccl_backend = _check_for_nccl_backend(group)
     current_device = None

     if device is not None:
@@ -133,45 +130,61 @@ def _broadcast_object_list(
             object_list[i] = unpickle_object


-def check_for_nccl_backend(group):
+def _check_for_nccl_backend(group):
     pg = group or c10d._get_default_group()
     # Gate PG wrapper check on Gloo availability.
     if c10d._GLOO_AVAILABLE:
-        # It is not expected for PG to be wrapped many times, but support it just
-        # in case
+        # It is not expected for PG to be wrapped many times, but support it just in case
        while isinstance(pg, c10d._ProcessGroupWrapper):
             pg = pg.wrapped_pg

     return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL


-def check_device(group):
-    is_nccl_backend = check_for_nccl_backend(group)
-    current_device = None
+def _check_device(group):
+    is_nccl_backend = _check_for_nccl_backend(group)
     current_device = torch.device("cpu")
     if is_nccl_backend:
         current_device = torch.device("cuda", torch.cuda.current_device())
     return current_device, is_nccl_backend


-TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"])
+TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"])
+P2PMetadata = namedtuple("P2PMetadata", ["tree_spec", "tensor_metadata", "non_tensor_obj_idx", "non_tensor_objs"])


-class P2PDataType(Enum):
-    Serialization = 0
-    Tensor = 1
-    List = 2
-    Dict = 3
+def create_send_metadata(
+    object: Any, strict: bool = True, return_tensor: bool = False
+) -> Union[P2PMetadata, Tuple[P2PMetadata, List[torch.Tensor]]]:
+    """
+    Args:
+        object (Any): object needed to be sent
+        strict (bool, optional): whether to check if the object is supported for fast send
+        return_tensor (bool, optional): whether to return tensor objects
+    """
+    objs, tree_spec = tree_flatten(object)
+    tensor_metadata, tensor_objs = [], []
+    non_tensor_obj_idx, non_tensor_objs = [], []
+    for idx, obj in enumerate(objs):
+        if isinstance(obj, torch.Tensor):
+            tensor_objs.append(obj)
+            tensor_metadata.append(TensorMetadata(obj.shape, obj.dtype, obj.requires_grad))
+        else:
+            non_tensor_obj_idx.append(idx)
+            non_tensor_objs.append(obj)
+    assert not strict or len(non_tensor_objs) == 0, "Only support tensor for fast send"
+    metadata = P2PMetadata(tree_spec, tensor_metadata, non_tensor_obj_idx, non_tensor_objs)
+    return metadata if not return_tensor else (metadata, tensor_objs)


-@dataclass
-class P2PMetadata:
-    data_type: P2PDataType
-    content: Union[List[TensorMetadata], TensorMetadata, Any]
-
-
-def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup):
+def _filling_ops_queue(
+    obj: Union[torch.Tensor, List[torch.Tensor]],
+    comm_op: Callable,
+    comm_rank: int,
+    ops_queue: List,
+    group: ProcessGroup,
+):
     if isinstance(obj, torch.Tensor):
         obj = obj.contiguous()
         op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
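The key change in this hunk: instead of special-casing tensors, lists, and dicts, `create_send_metadata` flattens an arbitrary pytree and records where the non-tensor leaves sit. A standalone sketch of that split (values are illustrative):

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    payload = {"hidden": torch.randn(2, 4), "seq_len": 128}  # mixed tensor / plain Python
    leaves, tree_spec = tree_flatten(payload)                # leaves in a fixed, spec-defined order

    tensor_objs = [x for x in leaves if isinstance(x, torch.Tensor)]  # fast P2P path
    non_tensor_obj_idx = [i for i, x in enumerate(leaves) if not isinstance(x, torch.Tensor)]
    # Tensors travel as raw buffers; everything else rides in the pickled
    # metadata together with tree_spec, so the receiver can rebuild `payload`.

    restored = tree_unflatten(leaves, tree_spec)  # lossless round trip
    assert torch.equal(restored["hidden"], payload["hidden"]) and restored["seq_len"] == 128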
@@ -179,47 +192,22 @@ def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup):
     else:
         for tensor_to_comm in obj:
             assert isinstance(tensor_to_comm, torch.Tensor)
-            filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
+            _filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)


-def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any):
-    if p2p_metadata.data_type == P2PDataType.Tensor:
-        metadata = p2p_metadata.content
-        tensor_recv = torch.empty(
-            metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
-        )
-        return tensor_recv
-    elif p2p_metadata.data_type in (P2PDataType.List, P2PDataType.Dict):
-        buffer_recv = []
-        for metadata in p2p_metadata.content:
-            tensor_recv = torch.empty(
-                metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
-            )
-            buffer_recv.append(tensor_recv)
-        return buffer_recv
-    else:
-        raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
-
-
-def create_fast_send_metadata(object: Any) -> P2PMetadata:
-    assert _check_if_fast_send_available(object)
-    if isinstance(object, torch.Tensor):
-        data_type = P2PDataType.Tensor
-        content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
-    elif isinstance(object, list):
-        data_type = P2PDataType.List
-        content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object]
-    elif isinstance(object, dict):
-        data_type = P2PDataType.Dict
-        content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()]
-    else:
-        raise RuntimeError("Cannot handle object of type {}".format(type(object)))
-    return P2PMetadata(data_type, content)
+def _create_recv_buffer(tensor_metadata: List[TensorMetadata], current_device) -> List[torch.Tensor]:
+    buffer_recv = []
+    for metadata in tensor_metadata:
+        tensor_recv = torch.empty(
+            metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
+        )
+        buffer_recv.append(tensor_recv)
+    return buffer_recv


 def _batch_send_recv_tensor(
-    send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]],
-    recv_tensor_metadata: Optional[P2PMetadata],
+    send_tensor_list: Optional[List[torch.Tensor]],
+    recv_tensor_metadata: Optional[List[TensorMetadata]],
     send_dst: Optional[int],
     recv_src: Optional[int],
     send_group: Optional[ProcessGroup],
@@ -227,16 +215,16 @@ def _batch_send_recv_tensor(
     current_device: Any,
 ) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
     buffer_recv = None
-    if recv_tensor_metadata is not None and recv_tensor_metadata.data_type != P2PDataType.Serialization:
-        buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
+    if recv_tensor_metadata is not None:
+        buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)

     ops = []
     if send_dst is not None and send_tensor_list is not None:
         assert send_group is not None
-        filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
     if recv_src is not None and buffer_recv is not None:
         assert recv_group is not None
-        filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)

     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
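`_filling_ops_queue` and `_batch_send_recv_tensor` build on `torch.distributed`'s batched P2P API. A minimal two-rank sketch of that pattern (assumes an initialized process group; not part of the diff):

    import torch
    import torch.distributed as dist

    rank = dist.get_rank()
    peer = 1 - rank  # two-rank example

    send_buf = torch.full((4,), float(rank))
    recv_buf = torch.empty(4)  # shape/dtype known from the exchanged metadata

    ops = [
        dist.P2POp(dist.isend, send_buf.contiguous(), peer),
        dist.P2POp(dist.irecv, recv_buf, peer),
    ]
    for req in dist.batch_isend_irecv(ops):  # send and recv posted together, avoiding deadlock
        req.wait()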
@@ -247,13 +235,13 @@ def _batch_send_recv_tensor(
         # However, the Megatron-LM does synchronization here
         # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
         # In case there is potential error, uncomment the following `torch.cuda.synchronize()`
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()

     return buffer_recv


 def _send_recv_serialization_object(
-    object: Any,
+    object: Optional[P2PMetadata],
     send_dst: Optional[int],
     recv_src: Optional[int],
     send_group: Optional[ProcessGroup],
@@ -274,14 +262,14 @@ def _send_recv_serialization_object(
             send_object_size_tensor = send_object_size_tensor.to(current_device)
             send_object_tensor = send_object_tensor.to(current_device)

-        filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)

     recv_object_size_tensor = None
     if recv_src is not None:
         recv_object_size_tensor = torch.empty(1, dtype=torch.long)
         if is_nccl_backend:
             recv_object_size_tensor = recv_object_size_tensor.to(current_device)
-        filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)

     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
@@ -289,19 +277,19 @@ def _send_recv_serialization_object(
             req.wait()

         # See the comment in `_batch_send_recv_tensor`
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()

     ops = []

     if send_dst is not None and send_object_tensor is not None:
-        filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)

     recv_object_tensor = None
     if recv_src is not None and recv_object_size_tensor is not None:
         recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
         if is_nccl_backend:
             recv_object_tensor = recv_object_tensor.to(current_device)
-        filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)

     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
@@ -309,7 +297,7 @@ def _send_recv_serialization_object(
             req.wait()

         # See the comment in `_batch_send_recv_tensor`
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()

     if recv_object_tensor is not None and recv_object_size_tensor is not None:
         recv_object_tensor = recv_object_tensor.type(torch.uint8)
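`_send_recv_serialization_object` is a two-phase exchange: a 1-element size tensor first, then the pickled payload. A simplified sketch using blocking send/recv instead of the batched ops (assumes an initialized two-rank group; not part of the diff):

    import pickle

    import torch
    import torch.distributed as dist

    if dist.get_rank() == 0:
        payload = torch.frombuffer(bytearray(pickle.dumps({"shape": (2, 4)})), dtype=torch.uint8)
        dist.send(torch.tensor([payload.numel()], dtype=torch.long), dst=1)  # phase 1: size
        dist.send(payload, dst=1)                                            # phase 2: bytes
    else:
        size = torch.empty(1, dtype=torch.long)
        dist.recv(size, src=0)                                 # phase 1: learn the buffer size
        payload = torch.empty(size.item(), dtype=torch.uint8)
        dist.recv(payload, src=0)                              # phase 2: receive and unpickle
        obj = pickle.loads(payload.numpy().tobytes())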
@@ -324,18 +312,6 @@ def _send_recv_serialization_object(
     return unpickle_object


-def _check_if_fast_send_available(object: Any) -> bool:
-    if isinstance(object, torch.Tensor):
-        return True
-    elif isinstance(object, list):
-        is_list_of_tensor = all([isinstance(v, torch.Tensor) for v in object])
-        return is_list_of_tensor
-    elif isinstance(object, dict):
-        is_dict_of_tensor = all([isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in object.items()])
-        return is_dict_of_tensor
-    return False
-
-
 def _communicate(
     object: Any,
     send_dst: Optional[int],
@@ -361,10 +337,15 @@ def _communicate(
     assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
     assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
     assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
-    send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object))
     assert (
-        metadata_recv is None or metadata_recv.data_type != P2PDataType.Serialization
-    ), "metadata_recv type must not be Serialization"
+        metadata_recv is None or len(metadata_recv.non_tensor_obj_idx) == 0
+    ), "metadata_recv should not contain non-tensor objects"
+
+    metadata_send, tensor_objs = None, None
+    if object is not None:
+        # NOTE: if object contains non-tensor objects, we have to send metadata
+        metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
+        send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0

     # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
     # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
@@ -372,9 +353,13 @@ def _communicate(
         assert send_prior_fallback is not None, "Priority must be set if fallback happens"
         if send_prior_fallback:
             _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
-            return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+            return _communicate(
+                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
+            )
         else:
-            recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+            recv_data = _communicate(
+                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
+            )
             _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
             return recv_data

@@ -387,8 +372,8 @@ def _communicate(
     assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
     assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)

-    current_send_device, is_send_nccl_backend = check_device(send_group)
-    current_recv_device, is_recv_nccl_backend = check_device(recv_group)
+    current_send_device, is_send_nccl_backend = _check_device(send_group)
+    current_recv_device, is_recv_nccl_backend = _check_device(recv_group)

     is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend

@@ -396,14 +381,6 @@ def _communicate(
     current_device = current_send_device

     if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
-        metadata_send = None
-        if send_dst is not None and send_metadata:
-            can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
-            if not can_fast_send:
-                metadata_send = P2PMetadata(P2PDataType.Serialization, object)
-            else:
-                metadata_send = create_fast_send_metadata(object)
-
         # Send and receive metadata
         _metadata_recv = _send_recv_serialization_object(
             object=metadata_send,
@@ -417,31 +394,26 @@ def _communicate(
         assert metadata_recv is None or _metadata_recv is None
         metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv

-    send_tensor_list = None
-    if isinstance(object, torch.Tensor):
-        send_tensor_list = object
-    elif isinstance(object, list):
-        send_tensor_list = object
-    elif isinstance(object, dict):
-        send_tensor_list = list(object.values())
-
     # Send and receive data
-    recv_buffer = _batch_send_recv_tensor(
-        send_tensor_list, metadata_recv, send_dst, recv_src, send_group, recv_group, current_device
-    )
+    recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
+    recv_tensor_objs = _batch_send_recv_tensor(
+        tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
+    )

     if metadata_recv is not None:
         assert isinstance(metadata_recv, P2PMetadata)
-        if metadata_recv.data_type == P2PDataType.Serialization:
-            return metadata_recv.content
-        else:
-            assert recv_buffer is not None
-            if metadata_recv.data_type in [P2PDataType.Tensor, P2PDataType.List]:
-                return recv_buffer
-            elif metadata_recv.data_type == P2PDataType.Dict:
-                return {k: v for k, v in zip([m.key for m in metadata_recv.content], recv_buffer)}
-            else:
-                raise ValueError("Unknown data type {}".format(metadata_recv.data_type))
+        tree_spec = metadata_recv.tree_spec
+        non_tensor_obj_idx = metadata_recv.non_tensor_obj_idx
+        non_tensor_objs = metadata_recv.non_tensor_objs
+
+        if recv_tensor_objs is None:
+            recv_tensor_objs = []
+
+        for idx in non_tensor_obj_idx:
+            recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
+        recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
+
+        return recv_object


 def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
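The final hunk is the receive-side inverse of `create_send_metadata`: non-tensor leaves are re-inserted at their recorded indices, then `tree_unflatten` rebuilds the original structure. A standalone sketch of that reconstruction (values are illustrative):

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    payload = {"hidden": torch.randn(2, 4), "seq_len": 128}
    leaves, tree_spec = tree_flatten(payload)

    # What the sender transmits: tensors via P2P, the rest inside the metadata.
    non_tensor_obj_idx = [i for i, x in enumerate(leaves) if not isinstance(x, torch.Tensor)]
    non_tensor_objs = [leaves[i] for i in non_tensor_obj_idx]
    recv_tensor_objs = [x for x in leaves if isinstance(x, torch.Tensor)]

    # Receive-side reconstruction, mirroring the loop in _communicate:
    for idx in non_tensor_obj_idx:
        recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
    recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
    assert recv_object["seq_len"] == 128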

colossalai/pipeline/schedule/interleaved_pp.py

@@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map

 from colossalai.interface import OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
@@ -130,7 +130,7 @@ class InterleavedSchedule(PipelineSchedule):
             if not self.stage_manager.is_first_stage():
                 input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
                 if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                    self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                    self.tensor_metadata_recv = create_send_metadata(input_tensor)

                 return input_tensor
@@ -149,7 +149,7 @@ class InterleavedSchedule(PipelineSchedule):
             if not self.stage_manager.is_last_stage():
                 output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
                 if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                    self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                    self.grad_metadata_recv = create_send_metadata(output_tensor_grad)

                 return output_tensor_grad
@@ -206,7 +206,7 @@ class InterleavedSchedule(PipelineSchedule):
                 )
                 self.send_tensor_metadata = not self.enable_metadata_cache
                 if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                    self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                    self.grad_metadata_recv = create_send_metadata(output_tensor_grad)

                 return output_tensor_grad

         # send only or recv only
@@ -238,7 +238,7 @@ class InterleavedSchedule(PipelineSchedule):
                 )
                 self.send_grad_metadata = not self.enable_metadata_cache
                 if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                    self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                    self.tensor_metadata_recv = create_send_metadata(input_tensor)

                 return input_tensor

         # send only or recv only
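# send only or recv only
The schedule changes are mechanical renames, but they rest on one property worth noting: every microbatch in these pipelines shares the same shapes and dtypes, so the metadata from the first receive describes all later ones. A hedged sketch of the caching pattern (`comm`, `prev_rank`, and `num_microbatches` are illustrative stand-ins):

    cached_metadata = None
    for _ in range(num_microbatches):
        # The first iteration pays the metadata handshake; later ones skip it.
        x = comm.recv_forward(prev_rank, metadata_recv=cached_metadata)
        if cached_metadata is None:
            cached_metadata = create_send_metadata(x)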

colossalai/pipeline/schedule/one_f_one_b.py

@@ -7,7 +7,7 @@ from torch.nn import Module
 from torch.utils._pytree import tree_map

 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
@@ -121,7 +121,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)

             return input_tensor
@@ -138,7 +138,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_last_stage():
             output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)

             return output_tensor_grad
@@ -188,7 +188,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             )
             self.send_tensor_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)

             return output_tensor_grad
@@ -214,7 +214,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             )
             self.send_grad_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)

             return input_tensor

tests/test_pipeline/test_p2p_communication.py

@@ -4,7 +4,7 @@ import torch.distributed as dist

 import colossalai
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
@@ -57,19 +57,15 @@ def check_p2p_communication():
             p2p.send_forward(data[-(i + 1)])
             assert recv_obj == data[i]

-    tensor_metadata = TensorMetadata(
-        key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
-    )
-    comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
     if rank == 0:
         recv_obj = p2p.send_forward_recv_backward(
             tensor,
             send_metadata=False,
-            metadata_recv=comm_metadata,
+            metadata_recv=create_send_metadata(tensor),
         )
         assert recv_obj == tensor
     elif rank == 1:
-        recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
+        recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
         assert recv_obj == tensor
         p2p.send_backward(tensor, send_metadata=False)
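A closing note on the test: because both ranks construct the same `tensor`, each can compute the identical `P2PMetadata` locally, so the metadata exchange is disabled on both ends (`send_metadata=False` on the sender, a precomputed `metadata_recv` on the receiver), leaving a single batched tensor transfer per step. A sketch reusing the test's own names:

    shared_metadata = create_send_metadata(tensor)  # identical on every rank
    # The receiver skips the metadata handshake entirely:
    recv_obj = p2p.recv_forward(metadata_recv=shared_metadata)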