diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index d32ff501f..5588aa578 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -5,20 +5,17 @@ import io
 import pickle
 import re
 from collections import namedtuple
-from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 from packaging.version import Version
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
+from torch.utils._pytree import tree_flatten, tree_unflatten
 
 from .stage_manager import PipelineStageManager
 
-_unpickler = pickle.Unpickler
-
 
 def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
     """transform tensor to object with unpickle.
@@ -42,7 +39,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     buf = bytes(buf_array)
 
     io_bytes = io.BytesIO(buf)
-    byte_pickler = _unpickler(io_bytes)
+    byte_pickler = pickle.Unpickler(io_bytes)
     unpickle = byte_pickler.load()
 
     return unpickle
@@ -67,7 +64,7 @@ def _broadcast_object_list(
         c10d._warn_not_in_group("broadcast_object_list")
         return
 
-    is_nccl_backend = check_for_nccl_backend(group)
+    is_nccl_backend = _check_for_nccl_backend(group)
     current_device = None
 
     if device is not None:
@@ -133,45 +130,61 @@ def _broadcast_object_list(
             object_list[i] = unpickle_object
 
 
-def check_for_nccl_backend(group):
+def _check_for_nccl_backend(group):
    pg = group or c10d._get_default_group()
    # Gate PG wrapper check on Gloo availability.
    if c10d._GLOO_AVAILABLE:
-        # It is not expected for PG to be wrapped many times, but support it just
-        # in case
+        # It is not expected for PG to be wrapped many times, but support it just in case
        while isinstance(pg, c10d._ProcessGroupWrapper):
            pg = pg.wrapped_pg
 
    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
 
 
-def check_device(group):
-    is_nccl_backend = check_for_nccl_backend(group)
-    current_device = None
-
+def _check_device(group):
+    is_nccl_backend = _check_for_nccl_backend(group)
    current_device = torch.device("cpu")
    if is_nccl_backend:
        current_device = torch.device("cuda", torch.cuda.current_device())
    return current_device, is_nccl_backend
 
 
-TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"])
+TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"])
+P2PMetadata = namedtuple("P2PMetadata", ["tree_spec", "tensor_metadata", "non_tensor_obj_idx", "non_tensor_objs"])
 
 
-class P2PDataType(Enum):
-    Serialization = 0
-    Tensor = 1
-    List = 2
-    Dict = 3
+def create_send_metadata(
+    object: Any, strict: bool = True, return_tensor: bool = False
+) -> Union[P2PMetadata, Tuple[P2PMetadata, List[torch.Tensor]]]:
+    """
+    Args:
+        object (Any): the object to be sent
+        strict (bool, optional): whether to require that the object contains only tensors (fast send)
+        return_tensor (bool, optional): whether to also return the flattened tensor objects
+    """
+    objs, tree_spec = tree_flatten(object)
+    tensor_metadata, tensor_objs = [], []
+    non_tensor_obj_idx, non_tensor_objs = [], []
+    for idx, obj in enumerate(objs):
+        if isinstance(obj, torch.Tensor):
+            tensor_objs.append(obj)
+            tensor_metadata.append(TensorMetadata(obj.shape, obj.dtype, obj.requires_grad))
+        else:
+            non_tensor_obj_idx.append(idx)
+            non_tensor_objs.append(obj)
+
+    assert not strict or len(non_tensor_objs) == 0, "Only tensors are supported for fast send"
+    metadata = P2PMetadata(tree_spec, tensor_metadata, non_tensor_obj_idx, non_tensor_objs)
+    return metadata if not return_tensor else (metadata, tensor_objs)
 
 
-@dataclass
-class P2PMetadata:
-    data_type: P2PDataType
-    content: Union[List[TensorMetadata], TensorMetadata, Any]
-
-
-def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup):
+def _filling_ops_queue(
+    obj: Union[torch.Tensor, List[torch.Tensor]],
+    comm_op: Callable,
+    comm_rank: int,
+    ops_queue: List,
+    group: ProcessGroup,
+):
    if isinstance(obj, torch.Tensor):
        obj = obj.contiguous()
        op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
@@ -179,47 +192,22 @@ def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: Li
    else:
        for tensor_to_comm in obj:
            assert isinstance(tensor_to_comm, torch.Tensor)
-            filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
+            _filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
 
 
-def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any):
-    if p2p_metadata.data_type == P2PDataType.Tensor:
-        metadata = p2p_metadata.content
+def _create_recv_buffer(tensor_metadata: List[TensorMetadata], current_device) -> List[torch.Tensor]:
+    buffer_recv = []
+    for metadata in tensor_metadata:
        tensor_recv = torch.empty(
            metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
        )
-        return tensor_recv
-    elif p2p_metadata.data_type in (P2PDataType.List, P2PDataType.Dict):
-        buffer_recv = []
-        for metadata in p2p_metadata.content:
-            tensor_recv = torch.empty(
-                metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
-            )
-            buffer_recv.append(tensor_recv)
-        return buffer_recv
-    else:
-        raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
-
-
-def create_fast_send_metadata(object: Any) -> P2PMetadata:
-    assert _check_if_fast_send_available(object)
-    if isinstance(object, torch.Tensor):
-        data_type = P2PDataType.Tensor
-        content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
-    elif isinstance(object, list):
-        data_type = P2PDataType.List
-        content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object]
-    elif isinstance(object, dict):
-        data_type = P2PDataType.Dict
-        content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()]
-    else:
-        raise RuntimeError("Cannot handle object of type {}".format(type(object)))
-    return P2PMetadata(data_type, content)
+        buffer_recv.append(tensor_recv)
+    return buffer_recv
 
 
 def _batch_send_recv_tensor(
-    send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]],
-    recv_tensor_metadata: Optional[P2PMetadata],
+    send_tensor_list: Optional[List[torch.Tensor]],
+    recv_tensor_metadata: Optional[List[TensorMetadata]],
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup],
@@ -227,16 +215,16 @@ def _batch_send_recv_tensor(
    current_device: Any,
 ) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
    buffer_recv = None
-    if recv_tensor_metadata is not None and recv_tensor_metadata.data_type != P2PDataType.Serialization:
-        buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
+    if recv_tensor_metadata is not None:
+        buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)
 
    ops = []
    if send_dst is not None and send_tensor_list is not None:
        assert send_group is not None
-        filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
    if recv_src is not None and buffer_recv is not None:
        assert recv_group is not None
-        filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
 
    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
@@ -247,13 +235,13 @@ def _batch_send_recv_tensor(
        # However, the Megatron-LM does synchronization here
        # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
        # In case there is potential error, uncomment the following `torch.cuda.synchronize()`
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()
 
    return buffer_recv
 
 
 def _send_recv_serialization_object(
-    object: Any,
+    object: Optional[P2PMetadata],
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup],
@@ -274,14 +262,14 @@ def _send_recv_serialization_object(
            send_object_size_tensor = send_object_size_tensor.to(current_device)
            send_object_tensor = send_object_tensor.to(current_device)
 
-        filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
 
    recv_object_size_tensor = None
    if recv_src is not None:
        recv_object_size_tensor = torch.empty(1, dtype=torch.long)
        if is_nccl_backend:
            recv_object_size_tensor = recv_object_size_tensor.to(current_device)
-        filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
 
    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
@@ -289,19 +277,19 @@ def _send_recv_serialization_object(
            req.wait()
 
        # See the comment in `_batch_send_recv_tensor`
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()
 
    ops = []
 
    if send_dst is not None and send_object_tensor is not None:
-        filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
 
    recv_object_tensor = None
    if recv_src is not None and recv_object_size_tensor is not None:
        recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
        if is_nccl_backend:
            recv_object_tensor = recv_object_tensor.to(current_device)
-        filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
 
    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
@@ -309,7 +297,7 @@ def _send_recv_serialization_object(
            req.wait()
 
        # See the comment in `_batch_send_recv_tensor`
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()
 
    if recv_object_tensor is not None and recv_object_size_tensor is not None:
        recv_object_tensor = recv_object_tensor.type(torch.uint8)
@@ -324,18 +312,6 @@ def _send_recv_serialization_object(
    return unpickle_object
 
 
-def _check_if_fast_send_available(object: Any) -> bool:
-    if isinstance(object, torch.Tensor):
-        return True
-    elif isinstance(object, list):
-        is_list_of_tensor = all([isinstance(v, torch.Tensor) for v in object])
-        return is_list_of_tensor
-    elif isinstance(object, dict):
-        is_dict_of_tensor = all([isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in object.items()])
-        return is_dict_of_tensor
-    return False
-
-
 def _communicate(
    object: Any,
    send_dst: Optional[int],
@@ -361,10 +337,15 @@ def _communicate(
    assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
    assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
    assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
-    send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object))
    assert (
-        metadata_recv is None or metadata_recv.data_type != P2PDataType.Serialization
-    ), "metadata_recv type must not be Serialization"
+        metadata_recv is None or len(metadata_recv.non_tensor_obj_idx) == 0
+    ), "metadata_recv should not contain non-tensor objects"
+
+    metadata_send, tensor_objs = None, None
+    if object is not None:
+        # NOTE: if object contains non-tensor objects, we have to send metadata
+        metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
+        send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
 
    # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
    # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
@@ -372,9 +353,13 @@ def _communicate(
        assert send_prior_fallback is not None, "Priority must be set if fallback happens"
        if send_prior_fallback:
            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
-            return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+            return _communicate(
+                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
+            )
        else:
-            recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
+            recv_data = _communicate(
+                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
+            )
            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
            return recv_data
 
@@ -387,8 +372,8 @@ def _communicate(
    assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
    assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
 
-    current_send_device, is_send_nccl_backend = check_device(send_group)
-    current_recv_device, is_recv_nccl_backend = check_device(recv_group)
+    current_send_device, is_send_nccl_backend = _check_device(send_group)
+    current_recv_device, is_recv_nccl_backend = _check_device(recv_group)
 
    is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend
 
@@ -396,14 +381,6 @@ def _communicate(
        current_device = current_send_device
 
    if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
-        metadata_send = None
-        if send_dst is not None and send_metadata:
-            can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
-            if not can_fast_send:
-                metadata_send = P2PMetadata(P2PDataType.Serialization, object)
-            else:
-                metadata_send = create_fast_send_metadata(object)
-
        # Send and receive metadata
        _metadata_recv = _send_recv_serialization_object(
            object=metadata_send,
@@ -417,31 +394,26 @@ def _communicate(
        assert metadata_recv is None or _metadata_recv is None
        metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
 
-    send_tensor_list = None
-    if isinstance(object, torch.Tensor):
-        send_tensor_list = object
-    elif isinstance(object, list):
-        send_tensor_list = object
-    elif isinstance(object, dict):
-        send_tensor_list = list(object.values())
-
    # Send and receive data
-    recv_buffer = _batch_send_recv_tensor(
-        send_tensor_list, metadata_recv, send_dst, recv_src, send_group, recv_group, current_device
+    recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
+    recv_tensor_objs = _batch_send_recv_tensor(
+        tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
    )
 
    if metadata_recv is not None:
        assert isinstance(metadata_recv, P2PMetadata)
-        if metadata_recv.data_type == P2PDataType.Serialization:
-            return metadata_recv.content
-        else:
-            assert recv_buffer is not None
-            if metadata_recv.data_type in [P2PDataType.Tensor, P2PDataType.List]:
-                return recv_buffer
-            elif metadata_recv.data_type == P2PDataType.Dict:
-                return {k: v for k, v in zip([m.key for m in metadata_recv.content], recv_buffer)}
-            else:
-                raise ValueError("Unknown data type {}".format(metadata_recv.data_type))
+        tree_spec = metadata_recv.tree_spec
+        non_tensor_obj_idx = metadata_recv.non_tensor_obj_idx
+        non_tensor_objs = metadata_recv.non_tensor_objs
+
+        if recv_tensor_objs is None:
+            recv_tensor_objs = []
+
+        for idx in non_tensor_obj_idx:
+            recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
+        recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
+
+        return recv_object
 
 
 def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index aa18a8520..0a01a1e78 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map
 
 from colossalai.interface import OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
 
@@ -130,7 +130,7 @@ class InterleavedSchedule(PipelineSchedule):
            if not self.stage_manager.is_first_stage():
                input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
                if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                    self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                    self.tensor_metadata_recv = create_send_metadata(input_tensor)
 
                return input_tensor
 
@@ -149,7 +149,7 @@ class InterleavedSchedule(PipelineSchedule):
            if not self.stage_manager.is_last_stage():
                output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
                if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                    self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                    self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
 
                return output_tensor_grad
 
@@ -206,7 +206,7 @@ class InterleavedSchedule(PipelineSchedule):
                )
                self.send_tensor_metadata = not self.enable_metadata_cache
                if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                    self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                    self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
                return output_tensor_grad
 
            # send only or recv only
@@ -238,7 +238,7 @@ class InterleavedSchedule(PipelineSchedule):
                )
                self.send_grad_metadata = not self.enable_metadata_cache
                if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                    self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                    self.tensor_metadata_recv = create_send_metadata(input_tensor)
                return input_tensor
 
            # send only or recv only
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index be60dcc74..cb078b25f 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -7,7 +7,7 @@ from torch.nn import Module
 from torch.utils._pytree import tree_map
 
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
 
@@ -121,7 +121,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        if not self.stage_manager.is_first_stage():
            input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
 
            return input_tensor
 
@@ -138,7 +138,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        if not self.stage_manager.is_last_stage():
            output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
            if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
 
            return output_tensor_grad
 
@@ -188,7 +188,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            )
            self.send_tensor_metadata = not self.enable_metadata_cache
            if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
 
            return output_tensor_grad
 
@@ -214,7 +214,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            )
            self.send_grad_metadata = not self.enable_metadata_cache
            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
 
            return input_tensor
 
diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py
index 1c859fd93..caf6e6bbb 100644
--- a/tests/test_pipeline/test_p2p_communication.py
+++ b/tests/test_pipeline/test_p2p_communication.py
@@ -4,7 +4,7 @@ import torch.distributed as dist
 
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
@@ -57,19 +57,15 @@ def check_p2p_communication():
            p2p.send_forward(data[-(i + 1)])
            assert recv_obj == data[i]
 
-    tensor_metadata = TensorMetadata(
-        key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
-    )
-    comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
    if rank == 0:
        recv_obj = p2p.send_forward_recv_backward(
            tensor,
            send_metadata=False,
-            metadata_recv=comm_metadata,
+            metadata_recv=create_send_metadata(tensor),
        )
        assert recv_obj == tensor
    elif rank == 1:
-        recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
+        recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
        assert recv_obj == tensor
        p2p.send_backward(tensor, send_metadata=False)
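
Note (reviewer sketch, not part of the diff): the refactor above replaces the P2PDataType enum and its Tensor/List/Dict special cases with a single pytree flatten/unflatten round trip. The snippet below is a minimal, standalone illustration of what create_send_metadata does on the sender side and what the tail of _communicate does on the receiver side; the variable names mirror the diff, the payload is invented, and no process group is required.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# A mixed payload such as a pipeline stage might produce: two tensors plus a
# non-tensor leaf (this object is made up for illustration).
obj = {"hidden_states": torch.randn(2, 4), "mask": torch.ones(2, 4), "step": 3}

# Sender side (create_send_metadata): flatten, then split the leaves into
# tensors (sent via batched isend/irecv) and everything else (pickled into
# P2PMetadata together with tree_spec and the insertion indices).
leaves, tree_spec = tree_flatten(obj)
tensor_objs = [x for x in leaves if isinstance(x, torch.Tensor)]
non_tensor_obj_idx = [i for i, x in enumerate(leaves) if not isinstance(x, torch.Tensor)]
non_tensor_objs = [x for x in leaves if not isinstance(x, torch.Tensor)]

# Receiver side (tail of _communicate): splice the pickled leaves back into
# the received tensor list at their recorded indices, then rebuild the
# original structure from tree_spec.
recv_leaves = list(tensor_objs)  # stand-in for tensors received over p2p
for idx in non_tensor_obj_idx:
    recv_leaves.insert(idx, non_tensor_objs.pop(0))
restored = tree_unflatten(recv_leaves, tree_spec)

assert restored["step"] == 3 and torch.equal(restored["mask"], obj["mask"])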
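
A second sketch, equally hypothetical, of the enable_metadata_cache pattern the schedules keep after this change: the first receive pays for a pickled P2PMetadata exchange, and every later microbatch reuses the cached copy so only raw tensors travel. FakeComm below is an invented stand-in for PipelineP2PCommunication; no torch.distributed call is made.

import torch
from torch.utils._pytree import tree_flatten

class FakeComm:
    """Counts how often the expensive metadata exchange would happen."""

    def __init__(self, payload):
        self.payload = payload
        self.metadata_exchanges = 0

    def recv_forward(self, metadata_recv=None):
        if metadata_recv is None:
            self.metadata_exchanges += 1  # would be a pickle + extra p2p round
        return self.payload

comm = FakeComm({"hidden_states": torch.randn(2, 4)})
tensor_metadata_recv = None
for _ in range(4):  # four microbatches
    out = comm.recv_forward(metadata_recv=tensor_metadata_recv)
    if tensor_metadata_recv is None:
        # cache the structure after the first receive, as the schedules do
        # via create_send_metadata(input_tensor)
        tensor_metadata_recv = tree_flatten(out)[1]

assert comm.metadata_exchanges == 1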