mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00

Merge branch 'main' into sync/npu
@@ -4,23 +4,20 @@
 import io
 import pickle
 import re
-from typing import Any, List, Optional, Union
+from collections import namedtuple
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
-from dataclasses import dataclass
-from enum import Enum
 from packaging.version import Version
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
+from torch.utils._pytree import tree_flatten, tree_unflatten

 from .stage_manager import PipelineStageManager

-_unpickler = pickle.Unpickler


-def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
+def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
     """transform tensor to object with unpickle.
     Info of the device in bytes stream will be modified into current device before unpickling

@@ -42,27 +39,13 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     buf = bytes(buf_array)

     io_bytes = io.BytesIO(buf)
-    byte_pickler = _unpickler(io_bytes)
+    byte_pickler = pickle.Unpickler(io_bytes)
     unpickle = byte_pickler.load()

     return unpickle

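For orientation, the trick this helper reverses can be shown in a few lines. The following standalone sketch (illustrative names, not ColossalAI API) pickles an object into a uint8 tensor and unpickles it back, which is the byte-level format the sender side produces:

import io
import pickle

import torch

def object_to_tensor(obj: object) -> torch.Tensor:
    # Serialize, then view the bytes as a uint8 tensor so a tensor-only
    # transport (e.g. NCCL) can carry arbitrary objects.
    buf = pickle.dumps(obj)
    return torch.frombuffer(bytearray(buf), dtype=torch.uint8)

def tensor_to_object(tensor: torch.Tensor, size: int) -> object:
    # Mirror of _cuda_safe_tensor_to_object, minus the device rewriting.
    buf = tensor[:size].numpy().tobytes()
    return pickle.Unpickler(io.BytesIO(buf)).load()

payload = {"step": 3, "shape": (2, 4)}
t = object_to_tensor(payload)
assert tensor_to_object(t, t.numel()) == payload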
-def check_for_nccl_backend(group):
-    pg = group or c10d._get_default_group()
-    # Gate PG wrapper check on Gloo availability.
-    if c10d._GLOO_AVAILABLE:
-        # It is not expected for PG to be wrapped many times, but support it just
-        # in case
-        while isinstance(pg, c10d._ProcessGroupWrapper):
-            pg = pg.wrapped_pg
-
-    return (
-        c10d.is_nccl_available() and
-        pg.name() == c10d.Backend.NCCL
-    )

 # NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
 def _broadcast_object_list(
     object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
 ):
@@ -70,20 +53,18 @@ def _broadcast_object_list(
     The only difference is that the object will be moved to the correct device after being unpickled.
     If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
     be updated with data sent from rank src.

     Args:
         object_list (List[Any]): list of object to broadcast
         src (int): source rank to broadcast
         dst (int): dst rank to broadcast
         device (:class:`torch.device`): device to do broadcast. current device in default

     """

     if c10d._rank_not_in_group(group):
         c10d._warn_not_in_group("broadcast_object_list")
         return

-    is_nccl_backend = check_for_nccl_backend(group)
+    is_nccl_backend = _check_for_nccl_backend(group)
     current_device = None

     if device is not None:
@@ -131,7 +112,7 @@ def _broadcast_object_list(

     if my_rank != src:
         for i, obj_size in enumerate(object_sizes_tensor):
-            obj_view = object_tensor[offset: offset + obj_size]
+            obj_view = object_tensor[offset : offset + obj_size]
             obj_view = obj_view.type(torch.uint8)
             if obj_view.device != torch.device("cpu"):
                 obj_view = obj_view.cpu()
@@ -149,80 +130,107 @@ def _broadcast_object_list(
             object_list[i] = unpickle_object


-def check_device(group):
-    is_nccl_backend = check_for_nccl_backend(group)
-    current_device = None
+def _check_for_nccl_backend(group):
+    pg = group or c10d._get_default_group()
+    # Gate PG wrapper check on Gloo availability.
+    if c10d._GLOO_AVAILABLE:
+        # It is not expected for PG to be wrapped many times, but support it just in case
+        while isinstance(pg, c10d._ProcessGroupWrapper):
+            pg = pg.wrapped_pg

+    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL


+def _check_device(group):
+    is_nccl_backend = _check_for_nccl_backend(group)
+    current_device = torch.device("cpu")
     if is_nccl_backend:
         current_device = torch.device("cuda", torch.cuda.current_device())
     return current_device, is_nccl_backend

-TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])
+TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"])
+P2PMetadata = namedtuple("P2PMetadata", ["tree_spec", "tensor_metadata", "non_tensor_obj_idx", "non_tensor_objs"])


-class P2PDataType(Enum):
-    serialization = 0
-    tensor = 1
-    list = 2
-    dict = 3
+def create_send_metadata(
+    object: Any, strict: bool = True, return_tensor: bool = False
+) -> Union[P2PMetadata, Tuple[P2PMetadata, List[torch.Tensor]]]:
+    """
+    Args:
+        object (Any): object needed to be sent
+        strict (bool, optional): whether to check if the object is supported for fast send
+        return_tensor (bool, optional): whether to return tensor objects
+    """
+    objs, tree_spec = tree_flatten(object)
+    tensor_metadata, tensor_objs = [], []
+    non_tensor_obj_idx, non_tensor_objs = [], []
+    for idx, obj in enumerate(objs):
+        if isinstance(obj, torch.Tensor):
+            tensor_objs.append(obj)
+            tensor_metadata.append(TensorMetadata(obj.shape, obj.dtype, obj.requires_grad))
+        else:
+            non_tensor_obj_idx.append(idx)
+            non_tensor_objs.append(obj)
+
+    assert not strict or len(non_tensor_objs) == 0, "Only support tensor for fast send"
+    metadata = P2PMetadata(tree_spec, tensor_metadata, non_tensor_obj_idx, non_tensor_objs)
+    return metadata if not return_tensor else (metadata, tensor_objs)


-@dataclass
-class P2PMetadata:
-    data_type: P2PDataType
-    content: Union[List[TensorMetadata], TensorMetadata, Any]

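create_send_metadata builds on torch.utils._pytree: any nested structure is flattened into leaves, tensor leaves are described by (shape, dtype, requires_grad) triples, and everything else rides along by value. A hedged single-process sketch of the same round trip, independent of the ColossalAI helpers (the payload is made up):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

payload = {"hidden": torch.randn(2, 4), "step": 7, "mask": torch.ones(2)}
leaves, spec = tree_flatten(payload)

# Separate tensor leaves (sent as raw buffers) from the rest (sent as pickled metadata).
tensors = [x for x in leaves if isinstance(x, torch.Tensor)]
others = [(i, x) for i, x in enumerate(leaves) if not isinstance(x, torch.Tensor)]

# Receiver side: re-insert the non-tensor leaves at their original indices,
# then rebuild the original structure.
rebuilt = list(tensors)
for i, x in others:
    rebuilt.insert(i, x)
restored = tree_unflatten(rebuilt, spec)
assert restored["step"] == 7 and torch.equal(restored["hidden"], payload["hidden"])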
-def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
+def _filling_ops_queue(
+    obj: Union[torch.Tensor, List[torch.Tensor]],
+    comm_op: Callable,
+    comm_rank: int,
+    ops_queue: List,
+    group: ProcessGroup,
+):
     if isinstance(obj, torch.Tensor):
         obj = obj.contiguous()
         op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
         ops_queue.append(op_to_add)
     else:
         for tensor_to_comm in obj:
-            tensor_to_comm = tensor_to_comm.contiguous()
-            op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group)
-            ops_queue.append(op_to_add)
+            assert isinstance(tensor_to_comm, torch.Tensor)
+            _filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)

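The queue built by _filling_ops_queue is later handed to dist.batch_isend_irecv, which posts all point-to-point operations at once. A minimal sketch of that pattern, assuming an already-initialized process group and a valid peer rank:

import torch
import torch.distributed as dist

def exchange(tensor: torch.Tensor, peer: int, group=None) -> torch.Tensor:
    """Post a send and a recv to the same peer and wait for both."""
    recv_buf = torch.empty_like(tensor)
    ops = [
        dist.P2POp(dist.isend, tensor.contiguous(), peer, group),
        dist.P2POp(dist.irecv, recv_buf, peer, group),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_buf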
-def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
-    if p2p_metadata.data_type == P2PDataType.tensor:
-        metadata = p2p_metadata.content
-        tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
-        return tensor_recv
-    elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict):
-        buffer_recv = []
-        for metadata in p2p_metadata.content:
-            tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
-            buffer_recv.append(tensor_recv)
-        return buffer_recv
-    else:
-        raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
+def _create_recv_buffer(tensor_metadata: List[TensorMetadata], current_device) -> List[torch.Tensor]:
+    buffer_recv = []
+    for metadata in tensor_metadata:
+        tensor_recv = torch.empty(
+            metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
+        )
+        buffer_recv.append(tensor_recv)
+    return buffer_recv

-def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
+def _batch_send_recv_tensor(
+    send_tensor_list: Optional[List[torch.Tensor]],
+    recv_tensor_metadata: Optional[List[TensorMetadata]],
+    send_dst: Optional[int],
+    recv_src: Optional[int],
+    send_group: Optional[ProcessGroup],
+    recv_group: Optional[ProcessGroup],
+    current_device: Any,
+) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
     buffer_recv = None
     if recv_tensor_metadata is not None:
-        buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
+        buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)

     ops = []

-    if send_dst is not None:
-        filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
-
-    if recv_src is not None:
-        assert buffer_recv is not None
-        filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
+    if send_dst is not None and send_tensor_list is not None:
+        assert send_group is not None
+        _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
+    if recv_src is not None and buffer_recv is not None:
+        assert recv_group is not None
+        _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)

     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()

-    torch.cuda.synchronize()

+    # Remove synchronization according to Pytorch's documentation
+    # However, the Megatron-LM does synchronization here
+    # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
@@ -233,12 +241,16 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re

 def _send_recv_serialization_object(
-    object: Any,
-    send_dst: Optional[int], recv_src: Optional[int],
-    send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup],
-    current_device,
-    is_nccl_backend):
+    object: Optional[P2PMetadata],
+    send_dst: Optional[int],
+    recv_src: Optional[int],
+    send_group: Optional[ProcessGroup],
+    recv_group: Optional[ProcessGroup],
+    current_device: Any,
+    is_nccl_backend: bool,
+) -> Optional[P2PMetadata]:
     ops = []

     send_object_tensor = None
     if object is not None and send_dst is not None:
         if Version(torch.__version__) >= Version("1.13.0"):
@@ -250,44 +262,40 @@ def _send_recv_serialization_object(
         send_object_size_tensor = send_object_size_tensor.to(current_device)
         send_object_tensor = send_object_tensor.to(current_device)

-        filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)

     recv_object_size_tensor = None
     if recv_src is not None:
         recv_object_size_tensor = torch.empty(1, dtype=torch.long)
         if is_nccl_backend:
             recv_object_size_tensor = recv_object_size_tensor.to(current_device)
-        filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)

     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()

-        torch.cuda.synchronize()

+        # See the comment in `_batch_send_recv_tensor`
+        # torch.cuda.synchronize()

     ops = []

     if send_dst is not None and send_object_tensor is not None:
-        filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
+        _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)

     recv_object_tensor = None
     if recv_src is not None and recv_object_size_tensor is not None:
         recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
         if is_nccl_backend:
             recv_object_tensor = recv_object_tensor.to(current_device)
-        filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
+        _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)

     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()

-        torch.cuda.synchronize()

+        # See the comment in `_batch_send_recv_tensor`
+        # torch.cuda.synchronize()

@@ -296,112 +304,119 @@ def _send_recv_serialization_object(
         if recv_object_tensor.device != torch.device("cpu"):
             recv_object_tensor = recv_object_tensor.cpu()

-        unpickle_object = _cuda_safe_tensor_to_object(
-            recv_object_tensor, recv_object_size_tensor.item())
+        unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())

-        if (
-            isinstance(unpickle_object, torch.Tensor)
-            and unpickle_object.device.index != torch.cuda.current_device()
-        ):
+        if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
             unpickle_object = unpickle_object.cuda()

         return unpickle_object


-def _check_if_fast_send_available(object):
-    if type(object) is torch.Tensor:
-        return True
-    elif type(object) is list:
-        is_list_of_tensor = all([type(v) is torch.Tensor for v in object])
-        return is_list_of_tensor
-    elif type(object) is dict:
-        is_dict_of_tensor = all([type(k) is str and type(
-            v) is torch.Tensor for k, v in object.items()])
-
-        return is_dict_of_tensor
-    return False

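_send_recv_serialization_object is a batched variant of the classic two-phase object exchange: first a fixed-size length tensor, then a payload buffer allocated from that length. Stripped of the batching and device handling, the protocol reduces to this hedged sketch (pg is an assumed, already-initialized process group; CPU tensors for simplicity):

import pickle

import torch
import torch.distributed as dist

def send_obj(obj, dst: int, pg=None) -> None:
    payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
    dist.send(torch.tensor([payload.numel()], dtype=torch.long), dst, group=pg)  # phase 1: size
    dist.send(payload, dst, group=pg)                                            # phase 2: bytes

def recv_obj(src: int, pg=None):
    size = torch.empty(1, dtype=torch.long)
    dist.recv(size, src, group=pg)                 # phase 1: size
    payload = torch.empty(size.item(), dtype=torch.uint8)
    dist.recv(payload, src, group=pg)              # phase 2: bytes
    return pickle.loads(payload.numpy().tobytes())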
 def _communicate(
-    object,
+    object: Any,
     send_dst: Optional[int],
     recv_src: Optional[int],
     send_group: Optional[ProcessGroup] = None,
     recv_group: Optional[ProcessGroup] = None,
+    send_metadata: bool = True,
+    metadata_recv: Optional[P2PMetadata] = None,
+    send_prior_fallback: Optional[bool] = None,
 ) -> Any:
-    if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
-        c10d._warn_not_in_group("_communicate")
-        return
+    """
+    Send and receive object from send_dst and recv_src respectively

-    current_send_device, is_send_nccl_backend = check_device(send_group)
-    current_recv_device, is_recv_nccl_backend = check_device(recv_group)
+    Args:
+        object (Any): object needed to be sent
+        send_dst (int): rank of the destination
+        recv_src (int): rank of the source
+        send_group (ProcessGroup, optional): process group of sender
+        recv_group (ProcessGroup, optional): process group of receiver
+        send_metadata (bool, optional): whether to send metadata
+        metadata_recv (P2PMetadata, optional): metadata of the object to be received
+    """
+    assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
+    assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
+    assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
+    assert (
+        metadata_recv is None or len(metadata_recv.non_tensor_obj_idx) == 0
+    ), "metadata_recv should not contain non-tensor objects"
+
+    metadata_send, tensor_objs = None, None
+    if object is not None:
+        # NOTE: if object contains non-tensor objects, we have to send metadata
+        metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
+        send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
+
+    # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
+    # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
+    if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
+        assert send_prior_fallback is not None, "Priority must be set if fallback happens"
+        if send_prior_fallback:
+            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
+            return _communicate(
+                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
+            )
+        else:
+            recv_data = _communicate(
+                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
+            )
+            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
+            return recv_data

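The fallback split above exists because a combined send+recv that still needs a metadata handshake cannot be posted atomically, and two peers that both send first would deadlock. The toy below (plain Python, no real transport) only illustrates the ordering rule that send_prior_fallback encodes:

def paired_exchange(stage_is_even: bool, do_send, do_recv):
    """Order a blocking send and a blocking recv so two peers never both
    block on send: even stages send first, odd stages receive first."""
    if stage_is_even:
        do_send()
        return do_recv()
    received = do_recv()
    do_send()
    return received

# Toy usage with stub callables standing in for the real transport.
print(paired_exchange(True, lambda: None, lambda: "grad"))  # -> grad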
+    # NOTE: only the following 5 cases are valid:
+    # 1. send() [needs extra metadata] and no recv()
+    # 2. recv() [needs extra metadata] and no send()
+    # 3. neither send() nor recv() need extra metadata
+    assert not (send_dst is not None and send_metadata) or recv_src is None
+    assert not (recv_src is not None and metadata_recv is None) or send_dst is None
+    assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
+    assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
+
+    current_send_device, is_send_nccl_backend = _check_device(send_group)
+    current_recv_device, is_recv_nccl_backend = _check_device(recv_group)

     is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend

     assert current_send_device == current_recv_device
     current_device = current_send_device

-    assert (send_dst is not None) or (recv_src is not None)
+    if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
+        # Send and receive metadata
+        _metadata_recv = _send_recv_serialization_object(
+            object=metadata_send,
+            send_dst=send_dst if send_metadata else None,
+            recv_src=recv_src if metadata_recv is None else None,
+            send_group=send_group if send_metadata else None,
+            recv_group=recv_group if metadata_recv is None else None,
+            current_device=current_device,
+            is_nccl_backend=is_nccl_backend,
+        )
+        assert metadata_recv is None or _metadata_recv is None
+        metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv

-    can_fast_send = False
-    send_metadata = None
-    if send_dst is not None:
-        can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
-        if not can_fast_send:
-            send_metadata = P2PMetadata(P2PDataType.serialization, object)
-        else:
-            if type(object) is torch.Tensor:
-                data_type = P2PDataType.tensor
-                content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
-            elif type(object) is list:
-                data_type = P2PDataType.list
-                content = []
-                for v in object:
-                    content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad))
-            elif type(object) is dict:
-                data_type = P2PDataType.dict
-                content = []
-                for k, v in object.items():
-                    content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad))
-            else:
-                raise ValueError('Cannot send object of type {}'.format(type(object)))
-            send_metadata = P2PMetadata(data_type, content)
+    # Send and receive data
+    recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
+    recv_tensor_objs = _batch_send_recv_tensor(
+        tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
+    )

-    recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
-    if recv_metadata is not None:
-        assert type(recv_metadata) is P2PMetadata
-        if recv_metadata.data_type == P2PDataType.serialization:
-            return recv_metadata.content
-    if not can_fast_send and send_dst is not None:
-        return
+    if metadata_recv is not None:
+        assert isinstance(metadata_recv, P2PMetadata)
+        tree_spec = metadata_recv.tree_spec
+        non_tensor_obj_idx = metadata_recv.non_tensor_obj_idx
+        non_tensor_objs = metadata_recv.non_tensor_objs

-    send_tensor_list = None
-    if type(object) is torch.Tensor:
-        send_tensor_list = object
-    elif type(object) is list:
-        send_tensor_list = object
-    elif type(object) is dict:
-        send_tensor_list = list(object.values())
+        if recv_tensor_objs is None:
+            recv_tensor_objs = []

-    recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)
+        for idx in non_tensor_obj_idx:
+            recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
+        recv_object = tree_unflatten(recv_tensor_objs, tree_spec)

-    if recv_metadata is not None:
-        assert recv_buffer is not None
-        if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
-            return recv_buffer
-        elif recv_metadata.data_type == P2PDataType.dict:
-            return {
-                k: v
-                for k, v in zip(
-                    [m.key for m in recv_metadata.content],
-                    recv_buffer,
-                )
-            }
-        else:
-            raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))
+        return recv_object

-def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
+def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
     """send anything to dst rank

     Args:
@@ -411,10 +426,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
     Returns:
         None
     """
-    _communicate(object, send_dst=dst, recv_src=None, send_group=group)
+    _communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)


-def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
+def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
     """recv anything from src

     Args:
@@ -423,7 +438,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
     Returns:
         Any: Object received from src.
     """
-    return _communicate(None, send_dst=None, recv_src=src, recv_group=group)
+    return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)

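A hedged usage sketch of these two wrappers (not part of the diff; it assumes torch.distributed is already initialized with ranks 0 and 1 sharing the process group pg):

import torch
import torch.distributed as dist

def demo(pg) -> None:
    if dist.get_rank() == 0:
        # any pytree works: tensors travel as raw buffers, the rest is pickled
        _send_object({"act": torch.randn(2, 2)}, src=0, dst=1, group=pg)
    else:
        obj = _recv_object(src=0, dst=1, group=pg)
        print(obj["act"].shape)  # torch.Size([2, 2])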
 def _p2p_comm(
@@ -436,7 +451,7 @@ def _p2p_comm(
     """
     Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.

-    Agrs:
+    Args:
         tensor_send_next (torch.Tensor): tensor to be sent to next stage
         recv_prev (bool): whether to receive tensor from previous stage
         peer (int): rank of the peer
@@ -467,7 +482,6 @@ def _p2p_comm(
             group=group,
         )
         ops.append(recv_prev_op)
-
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
@@ -490,7 +504,6 @@ def _p2p_comm(
             group=group,
         )
        ops.append(send_next_op)
-
    if tensor_recv_prev is not None:
        recv_prev_op = dist.P2POp(
            dist.irecv,
@@ -510,7 +523,7 @@ class PipelineP2PCommunication:
    def __init__(self, stage_manager: PipelineStageManager) -> None:
        self.stage_manager = stage_manager

-    def recv_forward(self, prev_rank: int = None) -> Any:
+    def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

         Args:
@@ -522,11 +535,16 @@ class PipelineP2PCommunication:
         if prev_rank is None:
             prev_rank = self.stage_manager.get_prev_rank()
         cur_rank = self.stage_manager.get_rank()
-        input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
+        input_tensor = _recv_object(
+            prev_rank,
+            cur_rank,
+            self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
+            metadata_recv=metadata_recv,
+        )

         return input_tensor

-    def recv_backward(self, next_rank: int = None) -> Any:
+    def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
         """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

         Args:
@@ -539,12 +557,15 @@ class PipelineP2PCommunication:
             next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
         output_tensor_grad = _recv_object(
-            next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank)
+            next_rank,
+            cur_rank,
+            self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
+            metadata_recv=metadata_recv,
         )

         return output_tensor_grad

-    def send_forward(self, output_object: Any, next_rank: int = None) -> None:
+    def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
         """Sends the input tensor to the next stage in pipeline.

         Args:
@@ -554,9 +575,15 @@ class PipelineP2PCommunication:
         if next_rank is None:
             next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
-        _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
+        _send_object(
+            output_object,
+            cur_rank,
+            next_rank,
+            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
+            send_metadata=send_metadata,
+        )

-    def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
+    def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.

         Args:
@@ -566,9 +593,22 @@ class PipelineP2PCommunication:
         if prev_rank is None:
             prev_rank = self.stage_manager.get_prev_rank()
         cur_rank = self.stage_manager.get_rank()
-        _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
+        _send_object(
+            input_object,
+            cur_rank,
+            prev_rank,
+            self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
+            send_metadata=send_metadata,
+        )

-    def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
+    def send_forward_recv_backward(
+        self,
+        input_object: Any,
+        next_rank: Optional[int] = None,
+        send_metadata: bool = True,
+        metadata_recv: Optional[P2PMetadata] = None,
+        send_prior_fallback: Optional[bool] = None,
+    ) -> Any:
         """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline

         Args:
@@ -581,11 +621,24 @@ class PipelineP2PCommunication:
         cur_rank = self.stage_manager.get_rank()
         group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
         return _communicate(
-            input_object, next_rank, next_rank,
-            send_group=group, recv_group=group,
+            input_object,
+            next_rank,
+            next_rank,
+            send_group=group,
+            recv_group=group,
+            send_metadata=send_metadata,
+            metadata_recv=metadata_recv,
+            send_prior_fallback=send_prior_fallback,
         )

-    def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
+    def send_backward_recv_forward(
+        self,
+        input_object: Any,
+        prev_rank: Optional[int] = None,
+        send_metadata: bool = True,
+        metadata_recv: Optional[P2PMetadata] = None,
+        send_prior_fallback: Optional[bool] = None,
+    ) -> Any:
         """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline

         Args:
@@ -597,37 +650,23 @@ class PipelineP2PCommunication:

         cur_rank = self.stage_manager.get_rank()
         group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
         return _communicate(
-            input_object, prev_rank, prev_rank,
-            send_group=group, recv_group=group,
-        )
-
-    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
-        """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
-
-        Args:
-            input_object (Any): Object to be sent.
-            prev_rank (int, optional): The rank of the sender of the tensor
-            next_rank (int, optional): The rank of the recipient of the tensor
-        """
-        if prev_rank is None:
-            prev_rank = self.stage_manager.get_prev_rank()
-        if next_rank is None:
-            next_rank = self.stage_manager.get_next_rank()
-
-        cur_rank = self.stage_manager.get_rank()
-        recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
-        send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
-        return _communicate(
             input_object,
-            send_dst=next_rank,
-            recv_src=prev_rank,
-            send_group=send_group,
-            recv_group=recv_group,
+            prev_rank,
+            prev_rank,
+            send_group=group,
+            recv_group=group,
+            send_metadata=send_metadata,
+            metadata_recv=metadata_recv,
+            send_prior_fallback=send_prior_fallback,
         )

     def p2p_communicate(
-        self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
+        self,
+        output_object: Any,
+        recv_pre: bool,
+        next_rank: Optional[int] = None,
+        comm_dtype: torch.dtype = torch.float16,
     ) -> None:
         """
         Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
@@ -636,10 +675,14 @@ class PipelineP2PCommunication:
            output_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the recipient of the tensor.
         """
-        if peer is None:
-            peer = self.stage_manager.get_next_rank()
+        if next_rank is None:
+            next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
         recv_tensor = _p2p_comm(
-            output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
+            output_object,
+            recv_pre,
+            next_rank,
+            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
+            comm_dtype,
         )
         return recv_tensor
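To see what the metadata plumbing above enables end to end, this small single-process check (hedged; it only needs colossalai to be importable, no distributed init) shows how create_send_metadata splits a mixed payload so a receiver could pre-allocate buffers before any tensor bytes arrive:

import torch
from colossalai.pipeline.p2p import create_send_metadata

obj = {"logits": torch.randn(2, 3), "label": 4}
metadata, tensors = create_send_metadata(obj, strict=False, return_tensor=True)
print(len(metadata.tensor_metadata), metadata.non_tensor_objs)  # 1 [4]
# A peer holding `metadata` knows every shape/dtype in advance, which is
# what makes the cached-metadata fast path in the schedules possible.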
@@ -1,14 +1,14 @@
 from functools import partial
-from typing import Any, Callable, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union

 import torch
 import torch.cuda
-from torch.nn import Module
+from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map

 from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager

 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
@@ -16,18 +16,35 @@ from .base import PipelineSchedule

 class InterleavedSchedule(PipelineSchedule):
-    def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
-        self.num_model_chunks = num_model_chunks
-        assert (
-            num_microbatches % self.num_model_chunks == 0
-        ), "Number of microbatches should be an integer multiple of number of model chunks"
+    def __init__(
+        self,
+        stage_manager: PipelineStageManager,
+        num_model_chunks: int,
+        num_microbatch: Optional[int] = None,
+        microbatch_size: Optional[int] = None,
+        enable_metadata_cache: bool = True,
+    ) -> None:
         super().__init__(stage_manager)
+        assert (
+            num_microbatch is not None or microbatch_size is not None
+        ), "Either num_microbatch or microbatch_size should be provided"

         self.comm = PipelineP2PCommunication(stage_manager)
-        self.num_microbatches = num_microbatches
-        self.batch: Optional[Any] = None
-        self.batch_size: Optional[int] = None
-        self.microbatch_offset: Optional[int] = None
-        self.microbatch_size: Optional[int] = None
+        self.num_microbatch = num_microbatch
+        self.microbatch_size = microbatch_size
+        self.num_model_chunks = num_model_chunks
+
+        self.batch: Any
+        self.batch_size: int
+        self.last_batch_size: Optional[int] = None
+        self.microbatch_offset: List[int]
+
+        # P2PMeta cache
+        self.enable_metadata_cache = enable_metadata_cache
+        self.send_tensor_metadata = True
+        self.send_grad_metadata = True
+        self.tensor_metadata_recv = None
+        self.grad_metadata_recv = None

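The four cache fields drive a handshake-once protocol: shape/dtype metadata flows on the first microbatch and is skipped afterwards. A toy flag simulation (values illustrative) of the flipping performed by the send/recv helpers further down:

enable_metadata_cache = True
send_tensor_metadata, tensor_metadata_recv = True, None
for step in range(3):
    did_handshake = send_tensor_metadata or tensor_metadata_recv is None
    # after a successful exchange the flags flip, mirroring recv_forward/send_forward below
    send_tensor_metadata = not enable_metadata_cache
    tensor_metadata_recv = tensor_metadata_recv or f"meta@{step}"
    print(step, did_handshake)
# prints: 0 True / 1 False / 2 False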
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -39,11 +56,37 @@ class InterleavedSchedule(PipelineSchedule):
         batch = next(data_iter)
         if device is not None:
             batch = tree_map(partial(to_device, device=device), batch)

+        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
         self.batch = batch
         self.batch_size = get_batch_size(batch)
-        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
-        assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
-        self.microbatch_size = self.batch_size // self.num_microbatches
+
+        if self.microbatch_size is None:
+            assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
+            self.microbatch_size = self.batch_size // self.num_microbatch
+        if self.num_microbatch is None:
+            assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
+            self.num_microbatch = self.batch_size // self.microbatch_size
+
+        if not self.forward_only:
+            assert self.last_batch_size is None or self.last_batch_size == self.batch_size
+            assert self.batch_size == self.microbatch_size * self.num_microbatch
+
+            assert (
+                self.num_microbatch % self.stage_manager.num_stages == 0
+            ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
+
+        if self.forward_only:
+            self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
+            # NOTE: disable metadata cache when batch size changes (not valid anymore)
+            if self.batch_size != self.last_batch_size:
+                self.enable_metadata_cache = False
+                self.send_tensor_metadata = True
+                self.send_grad_metadata = True
+                self.tensor_metadata_recv = None
+                self.grad_metadata_recv = None
+
+        self.last_batch_size = self.batch_size

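A worked example of the bookkeeping above, with illustrative numbers:

batch_size = 32
microbatch_size = 4                              # passed to __init__
num_microbatch = batch_size // microbatch_size   # -> 8, derived in load_batch
assert batch_size == microbatch_size * num_microbatch
# with 4 pipeline stages, 8 % 4 == 0 satisfies the divisibility assert above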
     def load_micro_batch(self, model_chunk_id: int) -> Any:
         """Load a micro batch from the current batch.
@@ -54,11 +97,12 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: Micro batch.
         """
+        assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
         self.microbatch_offset[model_chunk_id] += self.microbatch_size
         return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)

-    def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
+    def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
         """Helper method to get the model chunk ID given the iteration number.

         Args:
@@ -68,38 +112,13 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             int: The model chunk idx of the input microbatch_id
         """
-        microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
+        assert microbatch_id < self.num_microbatch * self.num_model_chunks
+        microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
         model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
-        if not forward:
+        if not is_forward:
             model_chunk_id = self.num_model_chunks - model_chunk_id - 1
         return model_chunk_id

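A worked example of the chunk-id mapping, assuming 4 pipeline stages and 2 model chunks (microbatches pass through chunk 0 first, and the backward order is mirrored):

num_stages, num_model_chunks = 4, 2

def chunk_id(mb: int, forward: bool) -> int:
    g = mb % (num_stages * num_model_chunks)
    c = g // num_stages
    return c if forward else num_model_chunks - 1 - c

print([chunk_id(i, True) for i in range(8)])   # [0, 0, 0, 0, 1, 1, 1, 1]
print([chunk_id(i, False) for i in range(8)])  # [1, 1, 1, 1, 0, 0, 0, 0]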
-    def is_first_stage(self, model_chunk_id: int) -> bool:
-        """Is the current virtual stage the first stage
-
-        Args:
-            model_chunk_id (int): The current model chunk idx.
-
-        Returns:
-            bool: Whether the current virtual stage is the first stage.
-        """
-        if self.stage_manager.is_first_stage() and model_chunk_id == 0:
-            return True
-        return False
-
-    def is_last_stage(self, model_chunk_id: int) -> bool:
-        """Is the current virtual stage the last stage
-
-        Args:
-            model_chunk_id (int): The current model chunk idx.
-
-        Returns:
-            bool: Whether the current virtual stage is the last stage.
-        """
-        if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
-            return True
-        return False

     def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
         For interleaved 1F1B.
@@ -111,12 +130,13 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: The input tensor or input tensor list.
         """
-        if self.is_first_stage(model_chunk_id):
-            input_tensor = None
-        else:
-            input_tensor = self.comm.recv_forward(prev_rank)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_first_stage():
+                input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
+                if self.enable_metadata_cache and self.tensor_metadata_recv is None:
+                    self.tensor_metadata_recv = create_send_metadata(input_tensor)

-        return input_tensor
+                return input_tensor

     def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
         """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
@@ -129,14 +149,15 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: The input gradient tensor or gradient tensor list.
         """
-        if self.is_last_stage(model_chunk_id):
-            output_tensor_grad = None
-        else:
-            output_tensor_grad = self.comm.recv_backward(next_rank)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_last_stage():
+                output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
+                if self.enable_metadata_cache and self.grad_metadata_recv is None:
+                    self.grad_metadata_recv = create_send_metadata(output_tensor_grad)

-        return output_tensor_grad
+                return output_tensor_grad

-    def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
+    def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
         """Sends the input tensor to the next stage in pipeline.
         For interleaved 1F1B.

@@ -145,10 +166,12 @@ class InterleavedSchedule(PipelineSchedule):
             output_object (Any): Object to be sent.
             next_rank (int, optional): The rank of the recipient of the tensor.
         """
-        if not self.is_last_stage(model_chunk_id):
-            self.comm.send_forward(output_object, next_rank)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_last_stage():
+                self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
+                self.send_tensor_metadata = not self.enable_metadata_cache

-    def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
+    def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
         For interleaved 1F1B.

@@ -157,12 +180,102 @@ class InterleavedSchedule(PipelineSchedule):
             input_object (Any): Object to be sent.
             prev_rank (int, optional): The rank of the recipient of the tensor
         """
-        if not self.is_first_stage(model_chunk_id):
-            self.comm.send_backward(input_object, prev_rank)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_first_stage():
+                self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
+                self.send_grad_metadata = not self.enable_metadata_cache

+    def send_forward_recv_backward(
+        self,
+        model_chunk_id_send: int,
+        model_chunk_id_recv: int,
+        output_tensor: Any,
+        next_rank: Optional[int] = None,
+        send_prior_fallback: Optional[bool] = None,
+    ) -> Any:
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
+            send_data = not self.stage_manager.is_last_stage()
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
+            recv_data = not self.stage_manager.is_last_stage()
+
+        if send_data and recv_data:
+            if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
+                send_prior_fallback = None  # must not fallback
+            output_tensor_grad = self.comm.send_forward_recv_backward(
+                output_tensor,
+                next_rank,
+                send_metadata=self.send_tensor_metadata,
+                metadata_recv=self.grad_metadata_recv,
+                send_prior_fallback=send_prior_fallback,
+            )
+            self.send_tensor_metadata = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.grad_metadata_recv is None:
+                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+            return output_tensor_grad
+
+        # send only or recv only
+        self.send_forward(model_chunk_id_send, output_tensor)
+        return self.recv_backward(model_chunk_id_recv)
+
+    def send_backward_recv_forward(
+        self,
+        model_chunk_id_send: int,
+        model_chunk_id_recv: int,
+        input_tensor_grad: Any,
+        prev_rank: Optional[int] = None,
+        send_prior_fallback: Optional[bool] = None,
+    ) -> Any:
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
+            send_data = not self.stage_manager.is_first_stage()
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
+            recv_data = not self.stage_manager.is_first_stage()
+
+        if send_data and recv_data:
+            if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
+                send_prior_fallback = None  # must not fallback
+            input_tensor = self.comm.send_backward_recv_forward(
+                input_tensor_grad,
+                prev_rank,
+                send_metadata=self.send_grad_metadata,
+                metadata_recv=self.tensor_metadata_recv,
+                send_prior_fallback=send_prior_fallback,
+            )
+            self.send_grad_metadata = not self.enable_metadata_cache
+            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
+                self.tensor_metadata_recv = create_send_metadata(input_tensor)
+            return input_tensor
+
+        # send only or recv only
+        self.send_backward(model_chunk_id_send, input_tensor_grad)
+        return self.recv_forward(model_chunk_id_recv)

||||
def send_forward_recv_forward(
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
|
||||
):
|
||||
if send_prior:
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
input_tensor = self.recv_forward(model_chunk_id_recv)
|
||||
else:
|
||||
input_tensor = self.recv_forward(model_chunk_id_recv)
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
|
||||
return input_tensor
|
||||
|
||||
def send_backward_recv_backward(
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
|
||||
):
|
||||
if send_prior:
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
model_chunk: Module,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
input_obj: Optional[dict],
|
||||
criterion: Callable,
|
||||
@@ -171,7 +284,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
) -> Union[torch.Tensor, dict]:
|
||||
"""Forward one step of the pipeline
|
||||
Args:
|
||||
model (Module): Model Chunk to be run
|
||||
model (ModuleList or Module): Model Chunk to be run
|
||||
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
|
||||
criterion (Callable): Criterion to calculate loss.
|
||||
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
|
||||
@@ -184,17 +297,25 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
|
||||
# for the first stage, input_obj is None
|
||||
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
|
||||
|
||||
if self.is_last_stage(model_chunk_id):
|
||||
loss = criterion(output_obj, micro_batch) / self.num_microbatches
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
return loss
|
||||
else:
|
||||
return output_obj
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if isinstance(model_chunk, ModuleList):
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
|
||||
else:
|
||||
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
||||
|
||||
if self.stage_manager.is_last_stage():
|
||||
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
return loss
|
||||
else:
|
||||
return output_obj
|
||||
|
||||
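The shardformer branch above relies on stage_indices to tell a whole-model forward which layer slice to run. The snippet below is purely illustrative (the layout is assumed; the real selection happens inside model_forward):

layers = list(range(12))          # stand-in for transformer blocks
stage_indices = [[0, 3], [6, 9]]  # two chunks of three layers on this stage (assumed layout)
for chunk_id, (start, end) in enumerate(stage_indices):
    active = layers[start:end]
    print(f"chunk {chunk_id} runs layers {active}")
# chunk 0 runs layers [0, 1, 2]
# chunk 1 runs layers [6, 7, 8]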
     def backward_step(
         self,
@@ -241,19 +362,193 @@ class InterleavedSchedule(PipelineSchedule):
                 input_obj_grad[k] = v.grad
         return input_obj_grad

+    def run_forward_only(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        assert self.forward_only
+
+        self.load_batch(data_iter)
+
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
+
+        accum_loss = None
+        if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_current_device())
+
+        model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
+        input_obj = self.recv_forward(model_chunk_id)
+
+        for i in range(self.num_microbatch * self.num_model_chunks):
+            last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
+            model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
+            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+
+            if not last_iteration:
+                input_obj = self.send_forward_recv_forward(
+                    model_chunk_id_send=model_chunk_id,
+                    model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
+                    output_tensor=output_obj,
+                    send_prior=self.stage_manager.stage % 2 == 0,
+                )
+            else:
+                self.send_forward(model_chunk_id, output_obj)

+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}

+    def run_forward_backward(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        """
+        Runs interleaved schedule, with communication between pipeline stages.
+        """
+        assert not self.forward_only
+
+        self.load_batch(data_iter)
+
+        num_microbatch = self.num_microbatch * self.num_model_chunks
+        num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
+        num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
+        num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
+        num_microbatch_remaining = num_microbatch - num_warmup_microbatch

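A worked example of the warmup-size formula, with illustrative values:

num_stages, stage = 4, 1
num_model_chunks, per_chunk = 2, 8
num_microbatch = per_chunk * num_model_chunks     # 16 total microbatch slots
warmup = (num_stages - stage - 1) * 2             # 4
warmup += (num_model_chunks - 1) * num_stages     # +4 -> 8 warmup forwards
warmup = min(warmup, num_microbatch)              # 8
remaining = num_microbatch - warmup               # 8 steady-state 1F1B steps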
+        # Input, output tensors only need to be saved when doing backward passes
+        input_objs = [[] for _ in range(self.num_model_chunks)]
+        output_objs = [[] for _ in range(self.num_model_chunks)]
+
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
+
+        accum_loss = None
+        if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_current_device())
+
+        model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
+        input_obj = self.recv_forward(model_chunk_id)
+        # Run warmup forward passes.
+        for i in range(num_warmup_microbatch):
+            last_iteration = i == num_warmup_microbatch - 1
+            model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
+            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+            input_objs[model_chunk_id].append(input_obj)
+            output_objs[model_chunk_id].append(output_obj)
+
+            if last_iteration and num_microbatch_remaining == 0:
+                self.send_forward(model_chunk_id, output_obj)
+            else:
+                input_obj = self.send_forward_recv_forward(
+                    model_chunk_id_send=model_chunk_id,
+                    model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
+                    output_tensor=output_obj,
+                    send_prior=self.stage_manager.stage % 2 == 0,
+                )
+
+        if num_microbatch_remaining > 0:
+            model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
+            output_obj_grad = self.recv_backward(model_chunk_id)
+
+        # Run 1F1B in steady state.
+        for i in range(num_microbatch_remaining):
+            last_iteration = i == num_microbatch_remaining - 1
+
+            model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
+            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+            # Add input_obj and output_obj to end of list.
+            input_objs[model_chunk_id].append(input_obj)
+            output_objs[model_chunk_id].append(output_obj)
+
+            model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+            # Pop input_obj and output_obj from the start of the list for the backward pass.
+            _input_obj = input_objs[model_chunk_id].pop(0)
+            _output_obj = output_objs[model_chunk_id].pop(0)
+            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
+
+            # NOTE: perform 2x communication for forward and backward
+            def send_forward_recv_backward():
+                if last_iteration and num_microbatch == num_microbatch_remaining:
+                    model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
+                    self.send_forward(model_chunk_id, output_obj)
+                else:
+                    output_obj_grad = self.send_forward_recv_backward(
+                        model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
+                        model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
+                        output_tensor=output_obj,
+                        send_prior_fallback=self.stage_manager.stage % 2 == 0,
+                    )
+                    return output_obj_grad
+
+            def send_backward_recv_forward():
+                if last_iteration:
+                    model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+                    self.send_backward(model_chunk_id, input_obj_grad)
+                else:
+                    input_obj = self.send_backward_recv_forward(
+                        model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
+                        model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
+                        input_tensor_grad=input_obj_grad,
+                        send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
+                    )
+                    return input_obj
+
+            if self.stage_manager.stage % 2 == 0:
+                output_obj_grad = send_forward_recv_backward()
+                input_obj = send_backward_recv_forward()
+            else:
+                input_obj = send_backward_recv_forward()
+                output_obj_grad = send_forward_recv_backward()
+
+        if num_microbatch_remaining == 0:
+            model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
+            output_obj_grad = self.recv_backward(model_chunk_id)
+        # Run cooldown backward passes.
+        for i in range(num_microbatch_remaining, num_microbatch):
+            last_iteration = i == num_microbatch - 1
+            model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+            _input_obj = input_objs[model_chunk_id].pop(0)
+            _output_obj = output_objs[model_chunk_id].pop(0)
+            # output_obj_grad = self.recv_backward(model_chunk_id)
+            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
+
+            if not last_iteration:
+                output_obj_grad = self.send_backward_recv_backward(
+                    model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
+                    model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
+                    input_tensor_grad=input_obj_grad,
+                    send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
+                )
+            else:
+                model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+                self.send_backward(model_chunk_id, input_obj_grad)
+
+        assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
+
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}

def forward_backward_step(
|
||||
self,
|
||||
model_chunk: Module,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> dict:
|
||||
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
|
||||
"""
|
||||
Args:
|
||||
model_chunk (List[Module]): Model Chunk to be trained.
|
||||
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
|
||||
data_iter (Iterable): Data iterator.
|
||||
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
|
||||
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
|
||||
@@ -263,118 +558,15 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
Returns:
|
||||
dict: A dict with keys: 'loss' and 'outputs'.
|
||||
"""
        forward_only = not torch.is_grad_enabled()
        self.forward_only = not torch.is_grad_enabled()
        if optimizer is None:
            assert forward_only, "Optimizer should be passed when doing backward."
            assert self.forward_only, "Optimizer should be passed when doing backward."

        self.load_batch(data_iter)
        num_model_chunks = len(model_chunk)

        # num_warmup_microbatches counts the steps during which not all stages are busy yet
        num_microbatches = self.num_microbatches * num_model_chunks
        if forward_only:
            num_warmup_microbatches = num_microbatches
        if self.forward_only:
            result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
        else:
            num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
            result = self.run_forward_backward(
                model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
            )
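        # For intuition, the warmup count above is easy to check by hand with assumed
        # values (4 stages, 2 model chunks, 8 microbatches per chunk; illustrative only):
        #
        #     num_stages, stage, num_model_chunks = 4, 1, 2
        #     num_microbatches = 8 * num_model_chunks
        #     num_warmup = (num_stages - stage - 1) * 2 + (num_model_chunks - 1) * num_stages
        #     assert min(num_warmup, num_microbatches) == 8  # stage 1: 8 forwards before its first backward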

        num_microbatches_remaining = num_microbatches - num_warmup_microbatches

        # Input, output tensors only need to be saved when doing backward passes
        input_objs = None
        output_objs = None

        if not forward_only:
            input_objs = [[] for _ in range(num_model_chunks)]
            output_objs = [[] for _ in range(num_model_chunks)]

        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None

        if return_loss and self.stage_manager.is_last_stage():
            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
        else:
            accum_loss = None

        # for ranks except the first one, get into recv state
        # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
        input_obj = self.recv_forward(0)
        input_objs[0].append(input_obj)
        # Run warmup forward passes.
        for i in range(num_warmup_microbatches):
            model_chunk_id = self.get_model_chunk_id(i, forward=True)

            # recv first on first rank to avoid sending or receiving at the same time
            if self.stage_manager.is_first_stage():
                input_obj = self.recv_forward(model_chunk_id)
                output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
                self.send_forward(model_chunk_id, output_obj)
                if not forward_only:
                    input_objs[model_chunk_id].append(input_obj)
                    output_objs[model_chunk_id].append(output_obj)
            else:
                output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
                if not forward_only:
                    output_objs[model_chunk_id].append(output_obj)
                self.send_forward(model_chunk_id, output_obj)
                if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
                    break
                else:
                    model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)

                    input_obj = self.recv_forward(model_chunk_id)
                    if not forward_only:
                        input_objs[model_chunk_id].append(input_obj)

        # Run 1F1B in steady state.
        for i in range(num_microbatches_remaining):
            model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
            last_iteration = i == (num_microbatches_remaining - 1)

            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
            if forward_only:
                self.send_forward(model_chunk_id, output_obj)

                if not last_iteration:
                    input_obj = self.recv_forward(model_chunk_id)

            else:
                self.send_forward(model_chunk_id, output_obj)
                # Add input_obj and output_obj to end of list.
                input_objs[model_chunk_id].append(input_obj)
                output_objs[model_chunk_id].append(output_obj)

                model_chunk_id = self.get_model_chunk_id(i, forward=False)
                output_obj_grad = self.recv_backward(model_chunk_id)

                # Pop input_obj and output_obj from the start of the list for
                # the backward pass.
                input_obj = input_objs[model_chunk_id].pop(0)
                output_obj = output_objs[model_chunk_id].pop(0)

                # backward
                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)

                if last_iteration:
                    input_obj = None
                else:
                    model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True)
                    input_obj = self.recv_forward(model_chunk_id)
                model_chunk_id = self.get_model_chunk_id(i, forward=False)
                self.send_backward(model_chunk_id, input_obj_grad)

        # Run cooldown backward passes.
        if not forward_only:
            for i in range(num_microbatches_remaining, num_microbatches):
                model_chunk_id = self.get_model_chunk_id(i, forward=False)
                # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
                input_obj = input_objs[model_chunk_id].pop(0)
                output_obj = output_objs[model_chunk_id].pop(0)

                output_obj_grad = self.recv_backward(model_chunk_id)
                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
                self.send_backward(model_chunk_id, input_obj_grad)

        if outputs is not None:
            outputs = merge_batch(outputs)
        return {"loss": accum_loss, "outputs": outputs}
        return result

@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import torch
import torch.cuda
@@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map

from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager

from ._utils import (
@@ -30,6 +30,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        stage_manager: PipelineStageManager,
        num_microbatches: Optional[int] = None,
        microbatch_size: Optional[int] = None,
        enable_metadata_cache: bool = True,
    ) -> None:
        """1F1B pipeline schedule.

@@ -42,13 +43,21 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        assert (
            num_microbatches is not None or microbatch_size is not None
        ), "Either num_microbatches or microbatch_size should be provided"

        self.comm = PipelineP2PCommunication(stage_manager)
        self.num_microbatches = num_microbatches
        self.microbatch_size = microbatch_size
        self.batch: Optional[Any] = None
        self.batch_size: Optional[int] = None
        self.last_batch_size: Optional[int] = None
        self.microbatch_offset: Optional[int] = None
        self._use_microbatch_size = num_microbatches is None

        # P2PMeta cache
        self.enable_metadata_cache = enable_metadata_cache
        self.send_tensor_metadata = True
        self.send_grad_metadata = True
        self.tensor_metadata_recv = None
        self.grad_metadata_recv = None
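        # The P2PMeta cache exists so shape/dtype metadata rides along only with the
        # first exchange; once *_metadata_recv is filled, later microbatches reuse it.
        # A rough plain-Python sketch of the pattern (a stand-in, not the real
        # communication layer; assumes tensor shapes stay fixed across microbatches):
        #
        #     def send_all_microbatches(tensor, num_microbatches):
        #         send_metadata = True  # first send carries the tensor description
        #         payloads = []
        #         for _ in range(num_microbatches):
        #             payload = {"data": tensor}
        #             if send_metadata:
        #                 payload["meta"] = {"shape": tuple(tensor.shape), "dtype": tensor.dtype}
        #             payloads.append(payload)
        #             send_metadata = False  # mirrors `self.send_tensor_metadata = not self.enable_metadata_cache`
        #         return payloads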

    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from data iterator.

@@ -60,24 +69,45 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        batch = next(data_iter)
        if device is not None:
            batch = tree_map(partial(to_device, device=device), batch)

        self.microbatch_offset = 0
        self.batch = batch
        self.batch_size = get_batch_size(batch)
        self.microbatch_offset = 0
        if not self._use_microbatch_size:
            assert (
                self.batch_size % self.num_microbatches == 0
            ), "Batch size should be divisible by the number of microbatches"

        if self.microbatch_size is None:
            assert self.batch_size % self.num_microbatches == 0, "Batch size should be divisible by # microbatches"
            self.microbatch_size = self.batch_size // self.num_microbatches
        else:
        if self.num_microbatches is None:
            assert self.batch_size % self.microbatch_size == 0, "Batch size should be divisible by the microbatch size"
            self.num_microbatches = self.batch_size // self.microbatch_size

        if not self.forward_only:
            assert self.last_batch_size is None or self.last_batch_size == self.batch_size
            assert self.batch_size == self.microbatch_size * self.num_microbatches

            assert (
                self.num_microbatches >= self.stage_manager.num_stages
            ), "Number of microbatches should be at least the number of stages"

        if self.forward_only:
            self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
            # NOTE: disable metadata cache when batch size changes (not valid anymore)
            if self.batch_size != self.last_batch_size:
                self.enable_metadata_cache = False
                self.send_tensor_metadata = True
                self.send_grad_metadata = True
                self.tensor_metadata_recv = None
                self.grad_metadata_recv = None

        self.last_batch_size = self.batch_size
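        # The size bookkeeping above can be sanity-checked with assumed numbers
        # (illustrative only):
        #
        #     batch_size, microbatch_size = 32, 4
        #     num_microbatches = batch_size // microbatch_size      # training path: 8
        #     assert batch_size == microbatch_size * num_microbatches
        #     # forward-only allows a ragged tail via ceiling division:
        #     assert (30 - 1) // microbatch_size + 1 == 8           # ceil(30 / 4) microbatches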

    def load_micro_batch(self) -> Any:
        """Load a micro batch from the current batch.

        Returns:
            Any: Micro batch.
        """
        assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
        micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
        self.microbatch_offset += self.microbatch_size
        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
@@ -92,12 +122,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        Returns:
            Any: The input tensor or input tensor list.
        """
        if self.stage_manager.is_first_stage():
            input_tensor = None
        else:
            input_tensor = self.comm.recv_forward(prev_rank)
        if not self.stage_manager.is_first_stage():
            input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                self.tensor_metadata_recv = create_send_metadata(input_tensor)

        return input_tensor
        return input_tensor

    def recv_backward(self, next_rank: int = None) -> Any:
        """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

@@ -109,14 +139,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        Returns:
            Any: The input gradient tensor or gradient tensor list.
        """
        if self.stage_manager.is_last_stage():
            output_tensor_grad = None
        else:
            output_tensor_grad = self.comm.recv_backward(next_rank)
        if not self.stage_manager.is_last_stage():
            output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
            if self.enable_metadata_cache and self.grad_metadata_recv is None:
                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)

        return output_tensor_grad
        return output_tensor_grad

    def send_forward(self, output_object: Any, next_rank: int = None) -> None:
    def send_forward(self, output_tensor: Any, next_rank: int = None) -> None:
        """Sends the output tensor to the next stage in pipeline.
        For 1F1B.

@@ -125,20 +155,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            next_rank (int, optional): The rank of the recipient of the tensor.
        """
        if not self.stage_manager.is_last_stage():
            self.comm.send_forward(output_object, next_rank)
            self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
            self.send_tensor_metadata = not self.enable_metadata_cache

    def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
        """Sends the output tensor to the next stage and copies the gradient tensor from the next stage in pipeline.
        For 1F1B.

        Args:
            output_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the recipient of the tensor.
        """
        if not self.stage_manager.is_last_stage():
            return self.comm.send_forward_recv_backward(output_object, next_rank)

    def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
    def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
        """Sends the gradient tensor to the previous stage in pipeline.
        For 1F1B.

@@ -147,9 +167,38 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            prev_rank (int, optional): The rank of the recipient of the tensor
        """
        if not self.stage_manager.is_first_stage():
            self.comm.send_backward(input_object, prev_rank)
            self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
            self.send_grad_metadata = not self.enable_metadata_cache

    def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
    def send_forward_recv_backward(
        self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
    ) -> Any:
        """Sends the output tensor to the next stage and copies the gradient tensor from the next stage in pipeline.
        For 1F1B.

        Args:
            output_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the recipient of the tensor.
        """
        if not self.stage_manager.is_last_stage():
            if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
                send_prior_fallback = None  # must not fallback
            output_tensor_grad = self.comm.send_forward_recv_backward(
                output_tensor,
                next_rank,
                send_metadata=self.send_tensor_metadata,
                metadata_recv=self.grad_metadata_recv,
                send_prior_fallback=send_prior_fallback,
            )
            self.send_tensor_metadata = not self.enable_metadata_cache
            if self.enable_metadata_cache and self.grad_metadata_recv is None:
                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)

            return output_tensor_grad
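        # The `send_prior_fallback = None` branch above encodes an ordering rule:
        # once this stage no longer sends metadata and already knows the incoming
        # gradient's layout, the combined send/recv must not fall back to a sequenced
        # pair of blocking calls. A sketch of that decision under assumed semantics
        # (hypothetical helper, not the library's internals):
        #
        #     def choose_send_prior_fallback(send_metadata, metadata_recv, stage_is_even):
        #         if not send_metadata and metadata_recv is not None:
        #             return None           # both directions fully described: no fallback needed
        #         return stage_is_even      # otherwise order blocking calls by stage parity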

    def send_backward_recv_forward(
        self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
    ) -> Any:
        """Sends the gradient tensor to the previous stage and copies the input tensor from the previous stage in pipeline.
        For 1F1B.

@@ -158,23 +207,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            prev_rank (int, optional): The rank of the recipient of the tensor.
        """
        if not self.stage_manager.is_first_stage():
            return self.comm.send_backward_recv_forward(output_object, prev_rank)
            if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
                send_prior_fallback = None  # must not fallback
            input_tensor = self.comm.send_backward_recv_forward(
                input_tensor_grad,
                prev_rank,
                send_metadata=self.send_grad_metadata,
                metadata_recv=self.tensor_metadata_recv,
                send_prior_fallback=send_prior_fallback,
            )
            self.send_grad_metadata = not self.enable_metadata_cache
            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                self.tensor_metadata_recv = create_send_metadata(input_tensor)

    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
        """Sends the input tensor to the next stage and copies the input tensor from the previous stage in pipeline.
        For 1F1B.

        Args:
            input_object (Any): Object to be sent.
            prev_rank (int, optional): The previous rank of the recipient of the tensor.
            next_rank (int, optional): The next rank of the recipient of the tensor.
        """
        if self.stage_manager.is_first_stage():
            return self.comm.send_forward(input_object, next_rank)
        elif self.stage_manager.is_last_stage():
            return self.comm.recv_forward(prev_rank)
        else:
            return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)

            return input_tensor

    def forward_step(
        self,
@@ -254,7 +300,38 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                input_obj_grad[k] = v.grad
        return input_obj_grad

    def forward_backward_step(
    def run_forward_only(
        self,
        model: Module,
        data_iter: Iterable,
        criterion: Callable[..., Any],
        return_loss: bool = False,
        return_outputs: bool = False,
    ) -> Dict:
        """
        Runs forward-only schedule, with communication between pipeline stages.
        """
        assert self.forward_only

        self.load_batch(data_iter)

        accum_loss = None
        if return_loss and self.stage_manager.is_last_stage():
            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None

        for _ in range(self.num_microbatches):
            input_obj = self.recv_forward()
            output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
            self.send_forward(output_obj)

        if outputs is not None:
            if isinstance(model, ModelWrapper):
                model = model.unwrap()
            outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
        return {"loss": accum_loss, "outputs": outputs}

    def run_forward_backward(
        self,
        model: Module,
        data_iter: Iterable,
@@ -262,23 +339,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        optimizer: Optional[OptimizerWrapper] = None,
        return_loss: bool = False,
        return_outputs: bool = False,
    ) -> dict:
        """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.

        Args:
            model (Module): Model to be trained.
            data_iter (Iterable): Data iterator.
            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
            return_loss (bool, optional): Whether to return loss. Defaults to False.
            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.

        Returns:
            dict: A dict with keys: 'loss' and 'outputs'.
    ) -> Dict:
        """
        forward_only = not torch.is_grad_enabled()
        if optimizer is None:
            assert forward_only, "Optimizer should be passed when doing backward."
        Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
        """
        assert not self.forward_only

        self.load_batch(data_iter)

@@ -288,30 +353,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

        # Input, output tensors only need to be saved when doing backward passes
        input_objs = None
        output_objs = None
        input_objs, output_objs = [], []

        if not forward_only:
            input_objs = []
            output_objs = []

        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
        accum_loss = None
        if return_loss and self.stage_manager.is_last_stage():
            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
        else:
            accum_loss = None
            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None

        # Run warmup forward passes.
        for i in range(num_warmup_microbatches):
            input_obj = self.recv_forward()

            output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)

            self.send_forward(output_obj)

            if not forward_only:
                input_objs.append(input_obj)
                output_objs.append(output_obj)
            input_objs.append(input_obj)
            output_objs.append(output_obj)

        # Before running 1F1B, need to receive first forward tensor.
        # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -324,44 +379,72 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            last_iteration = i == (num_microbatches_remaining - 1)

            output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
            if forward_only:
                self.send_forward(output_obj)
            output_obj_grad = self.send_forward_recv_backward(
                output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
            )
            # Add input_obj and output_obj to end of list.
            input_objs.append(input_obj)
            output_objs.append(output_obj)

                if not last_iteration:
                    input_obj = self.recv_forward()
            else:
                # TODO adjust here
                self.send_forward(output_obj)
                output_obj_grad = self.recv_backward()
            # Pop input_obj and output_obj from the start of the list for
            # the backward pass.
            input_obj = input_objs.pop(0)
            output_obj = output_objs.pop(0)
            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)

                # Add input_obj and output_obj to end of list.
                input_objs.append(input_obj)
                output_objs.append(output_obj)

                # Pop input_obj and output_obj from the start of the list for
                # the backward pass.
                input_obj = input_objs.pop(0)
                output_obj = output_objs.pop(0)
                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)

                if last_iteration:
                    input_obj = None
                else:
                    input_obj = self.recv_forward()
            if last_iteration:
                self.send_backward(input_obj_grad)
            else:
                input_obj = self.send_backward_recv_forward(
                    input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
                )

        # Run cooldown backward passes.
        if not forward_only:
            for i in range(num_warmup_microbatches):
                input_obj = input_objs.pop(0)
                output_obj = output_objs.pop(0)
        for i in range(num_warmup_microbatches):
            input_obj = input_objs.pop(0)
            output_obj = output_objs.pop(0)

                output_obj_grad = self.recv_backward()
                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
                self.send_backward(input_obj_grad)
            output_obj_grad = self.recv_backward()
            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
            self.send_backward(input_obj_grad)

        assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)

        if outputs is not None:
            if isinstance(model, ModelWrapper):
                model = model.unwrap()
            outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
        return {"loss": accum_loss, "outputs": outputs}

    def forward_backward_step(
        self,
        model: Module,
        data_iter: Iterable,
        criterion: Callable[..., Any],
        optimizer: Optional[OptimizerWrapper] = None,
        return_loss: bool = False,
        return_outputs: bool = False,
    ) -> dict:
        """
        Args:
            model (Module): Model to be trained.
            data_iter (Iterable): Data iterator.
            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
            return_loss (bool, optional): Whether to return loss. Defaults to False.
            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.

        Returns:
            dict: Dictionary containing loss and outputs.
        """

        self.forward_only = not torch.is_grad_enabled()
        if optimizer is None:
            assert self.forward_only, "Optimizer should be passed when doing backward."

        if self.forward_only:
            result = self.run_forward_only(model, data_iter, criterion, return_loss, return_outputs)
        else:
            result = self.run_forward_backward(model, data_iter, criterion, optimizer, return_loss, return_outputs)

        return result
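        # forward_backward_step is now a thin dispatcher keyed on PyTorch's grad mode.
        # A hedged usage sketch (construction of `schedule`, `model`, `criterion` and
        # the iterators is elided and assumed):
        #
        #     out = schedule.forward_backward_step(model, iter(train_data), criterion, optimizer, return_loss=True)
        #     with torch.no_grad():  # grad disabled routes the same call through run_forward_only
        #         eval_out = schedule.forward_backward_step(model, iter(val_data), criterion, return_loss=True)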
@@ -1,3 +1,4 @@
import contextlib
from typing import Dict, List, Optional, Tuple

import torch.distributed as dist
@@ -19,7 +20,15 @@ class PipelineStageManager:
        stage (int): The current stage.
    """

    def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None:
    def __init__(
        self,
        pg_mesh: ProcessGroupMesh,
        pipeline_axis: int,
        enable_interleave: bool = False,
        num_model_chunks: int = 1,
    ) -> None:
        assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"

        self.pg_mesh = pg_mesh
        self.pipeline_axis = pipeline_axis
        self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -43,29 +52,56 @@ class PipelineStageManager:
            ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
            self.p2p_groups[tuple(ranks_in_group)] = group

        if is_virtual:
        self.is_interleave = enable_interleave
        if enable_interleave:
            # use circle p2p communication
            # add the process group of the first rank and the last rank
            # only used in interleaved pipeline for now
            group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
            if self.stage in [stages[0], stages[-1]]:
                ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
                self.p2p_groups[tuple(ranks_in_group)] = group

    def is_first_stage(self) -> bool:
        # for interleaved pipeline parallel, each device is responsible for multiple chunks of layers
        self.num_model_chunks: int = num_model_chunks

        # for shardformer, hold stage indices of model
        self.stage_indices: List[Tuple[int, int]]
        # for shardformer, hold model chunk id
        self.model_chunk_id: Optional[int] = None

    def is_first_stage(self, ignore_chunk: bool = False) -> bool:
"""Is the current stage the first stage.
|
||||
|
||||
NOTE:
|
||||
1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device.
|
||||
2. invoke is_first_stage() with ignore_chunk=True is equivalent to invoke is_first_device()
|
||||
|
||||
Returns:
|
||||
bool: Whether the current stage is the first stage.
|
||||
"""
|
||||
return self.stage == 0
|
||||
assert isinstance(ignore_chunk, bool)
|
||||
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
|
||||
if not self.is_interleave or ignore_chunk:
|
||||
return self.stage == 0
|
||||
else:
|
||||
return self.stage == 0 and self.model_chunk_id == 0
|
||||

    def is_last_stage(self) -> bool:
    def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        """Is the current stage the last stage.

        NOTE:
            1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device.
            2. invoking is_last_stage() with ignore_chunk=True is equivalent to invoking is_last_device()

        Returns:
            bool: Whether the current stage is the last stage.
        """
        return self.stage == self.num_stages - 1
        assert isinstance(ignore_chunk, bool)
        assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
        if not self.is_interleave or ignore_chunk:
            return self.stage == self.num_stages - 1
        else:
            return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
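        # Under interleaving, "first/last stage" is a property of the (device stage,
        # model chunk) pair rather than of the device alone. With assumed sizes of
        # 4 stages and 2 chunks per device:
        #
        #     num_stages, num_model_chunks = 4, 2
        #     is_last = lambda stage, chunk: stage == num_stages - 1 and chunk == num_model_chunks - 1
        #     assert is_last(3, 1) and not is_last(3, 0)  # only the last chunk of the last device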

    @property
    def num_stages(self) -> int:
@@ -133,3 +169,10 @@ class PipelineStageManager:
            ProcessGroup: Process group of the given stages.
        """
        return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)

    @contextlib.contextmanager
    def switch_model_chunk_id(self, model_chunk_id: int):
        old_model_chunk_id = self.model_chunk_id
        self.model_chunk_id = model_chunk_id
        yield
        self.model_chunk_id = old_model_chunk_id
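        # The previous chunk id is restored on exit, so chunk switches can nest safely.
        # A hedged usage sketch (a configured `stage_manager` is assumed):
        #
        #     with stage_manager.switch_model_chunk_id(model_chunk_id=1):
        #         last = stage_manager.is_last_stage()  # resolves against chunk 1
        #     # previous model_chunk_id restored here automatically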