Merge branch 'main' into sync/npu

This commit is contained in:
ver217
2024-01-18 12:05:21 +08:00
152 changed files with 8641 additions and 2138 deletions


@@ -4,23 +4,20 @@
import io
import pickle
import re
from typing import Any, List, Optional, Union
from collections import namedtuple
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d
from torch.utils._pytree import tree_flatten, tree_unflatten
from .stage_manager import PipelineStageManager
_unpickler = pickle.Unpickler
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
"""transform tensor to object with unpickle.
Info of the device in bytes stream will be modified into current device before unpickling
@@ -42,27 +39,13 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
buf = bytes(buf_array)
io_bytes = io.BytesIO(buf)
byte_pickler = _unpickler(io_bytes)
byte_pickler = pickle.Unpickler(io_bytes)
unpickle = byte_pickler.load()
return unpickle
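The device fix-up above operates on the raw pickle bytes before loading. A minimal sketch of the idea, assuming the sender embedded a device string such as cuda:0 (the helper name is hypothetical and not part of this diff):
import re
import torch

def _rewrite_cuda_device_bytes(buf: bytes) -> bytes:
    # Illustration only: replace any embedded "cuda:<idx>" with the local device
    # so the subsequent pickle.Unpickler.load() materialises tensors locally.
    local = f"cuda:{torch.cuda.current_device()}".encode()
    return re.sub(rb"cuda:\d+", local, buf)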
def check_for_nccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
# It is not expected for PG to be wrapped many times, but support it just
# in case
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg
return (
c10d.is_nccl_available() and
pg.name() == c10d.Backend.NCCL
)
# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
def _broadcast_object_list(
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
):
@@ -70,20 +53,18 @@ def _broadcast_object_list(
The only difference is that the object will be moved to the correct device after being unpickled.
If local_rank == src, the object list is broadcast to the other ranks. Otherwise, the object list
is updated with the data sent from rank src.
Args:
object_list (List[Any]): list of objects to broadcast
src (int): source rank of the broadcast
group (ProcessGroup): process group to broadcast within
device (:class:`torch.device`): device on which to perform the broadcast; defaults to the current device
"""
if c10d._rank_not_in_group(group):
c10d._warn_not_in_group("broadcast_object_list")
return
is_nccl_backend = check_for_nccl_backend(group)
is_nccl_backend = _check_for_nccl_backend(group)
current_device = None
if device is not None:
@@ -131,7 +112,7 @@ def _broadcast_object_list(
if my_rank != src:
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset: offset + obj_size]
obj_view = object_tensor[offset : offset + obj_size]
obj_view = obj_view.type(torch.uint8)
if obj_view.device != torch.device("cpu"):
obj_view = obj_view.cpu()
@@ -149,80 +130,107 @@ def _broadcast_object_list(
object_list[i] = unpickle_object
def check_device(group):
is_nccl_backend = check_for_nccl_backend(group)
current_device = None
def _check_for_nccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
# It is not expected for PG to be wrapped many times, but support it just in case
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
def _check_device(group):
is_nccl_backend = _check_for_nccl_backend(group)
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
return current_device, is_nccl_backend
TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])
TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"])
P2PMetadata = namedtuple("P2PMetadata", ["tree_spec", "tensor_metadata", "non_tensor_obj_idx", "non_tensor_objs"])
class P2PDataType(Enum):
serialization = 0
tensor = 1
list = 2
dict = 3
def create_send_metadata(
object: Any, strict: bool = True, return_tensor: bool = False
) -> Union[P2PMetadata, Tuple[P2PMetadata, List[torch.Tensor]]]:
"""
Args:
object (Any): object needed to be sent
strict (bool, optional): whether to check if the object is supported for fast send
return_tensor (bool, optional): whether to return tensor objects
"""
objs, tree_spec = tree_flatten(object)
tensor_metadata, tensor_objs = [], []
non_tensor_obj_idx, non_tensor_objs = [], []
for idx, obj in enumerate(objs):
if isinstance(obj, torch.Tensor):
tensor_objs.append(obj)
tensor_metadata.append(TensorMetadata(obj.shape, obj.dtype, obj.requires_grad))
else:
non_tensor_obj_idx.append(idx)
non_tensor_objs.append(obj)
assert not strict or len(non_tensor_objs) == 0, "Only support tensor for fast send"
metadata = P2PMetadata(tree_spec, tensor_metadata, non_tensor_obj_idx, non_tensor_objs)
return metadata if not return_tensor else (metadata, tensor_objs)
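A hedged usage sketch of create_send_metadata (the dict contents are made up for illustration): tensor leaves are pulled out for batched isend/irecv, while non-tensor leaves travel inside the pickled metadata.
import torch

obj = {"hidden_states": torch.randn(2, 4), "step": 3}  # illustrative payload
metadata, tensors = create_send_metadata(obj, strict=False, return_tensor=True)
# tensors -> [the (2, 4) tensor]; metadata.non_tensor_objs -> [3]
# The receiver re-inserts the non-tensor leaves at metadata.non_tensor_obj_idx and
# calls tree_unflatten(leaves, metadata.tree_spec) to rebuild the original dict.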
@dataclass
class P2PMetadata:
data_type: P2PDataType
content: Union[List[TensorMetadata], TensorMetadata, Any]
def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
def _filling_ops_queue(
obj: Union[torch.Tensor, List[torch.Tensor]],
comm_op: Callable,
comm_rank: int,
ops_queue: List,
group: ProcessGroup,
):
if isinstance(obj, torch.Tensor):
obj = obj.contiguous()
op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
ops_queue.append(op_to_add)
else:
for tensor_to_comm in obj:
tensor_to_comm = tensor_to_comm.contiguous()
op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group)
ops_queue.append(op_to_add)
assert isinstance(tensor_to_comm, torch.Tensor)
_filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
if p2p_metadata.data_type == P2PDataType.tensor:
metadata = p2p_metadata.content
tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
return tensor_recv
elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict):
buffer_recv = []
for metadata in p2p_metadata.content:
tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
buffer_recv.append(tensor_recv)
return buffer_recv
else:
raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
def _create_recv_buffer(tensor_metadata: List[TensorMetadata], current_device) -> List[torch.Tensor]:
buffer_recv = []
for metadata in tensor_metadata:
tensor_recv = torch.empty(
metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
)
buffer_recv.append(tensor_recv)
return buffer_recv
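For illustration, this is how one cached TensorMetadata entry turns into an empty receive buffer (the shape, dtype, and device are assumptions, not values from this diff):
meta = [TensorMetadata(shape=(2, 4), dtype=torch.float32, requires_grad=True)]
buffers = _create_recv_buffer(meta, torch.device("cuda"))  # one empty (2, 4) float32 tensor on the GPU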
def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
def _batch_send_recv_tensor(
send_tensor_list: Optional[List[torch.Tensor]],
recv_tensor_metadata: Optional[List[TensorMetadata]],
send_dst: Optional[int],
recv_src: Optional[int],
send_group: Optional[ProcessGroup],
recv_group: Optional[ProcessGroup],
current_device: Any,
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
buffer_recv = None
if recv_tensor_metadata is not None:
buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)
ops = []
if send_dst is not None:
filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if recv_src is not None:
assert buffer_recv is not None
filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
if send_dst is not None and send_tensor_list is not None:
assert send_group is not None
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if recv_src is not None and buffer_recv is not None:
assert recv_group is not None
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
torch.cuda.synchronize()
# Remove synchronization according to Pytorch's documentation
# However, the Megatron-LM does synchronization here
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
@@ -233,12 +241,16 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
def _send_recv_serialization_object(
object: Any,
send_dst: Optional[int], recv_src: Optional[int],
send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup],
current_device,
is_nccl_backend):
object: Optional[P2PMetadata],
send_dst: Optional[int],
recv_src: Optional[int],
send_group: Optional[ProcessGroup],
recv_group: Optional[ProcessGroup],
current_device: Any,
is_nccl_backend: bool,
) -> Optional[P2PMetadata]:
ops = []
send_object_tensor = None
if object is not None and send_dst is not None:
if Version(torch.__version__) >= Version("1.13.0"):
@@ -250,44 +262,40 @@ def _send_recv_serialization_object(
send_object_size_tensor = send_object_size_tensor.to(current_device)
send_object_tensor = send_object_tensor.to(current_device)
filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
recv_object_size_tensor = None
if recv_src is not None:
recv_object_size_tensor = torch.empty(1, dtype=torch.long)
if is_nccl_backend:
recv_object_size_tensor = recv_object_size_tensor.to(current_device)
filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
torch.cuda.synchronize()
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
ops = []
if send_dst is not None and send_object_tensor is not None:
filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
recv_object_tensor = None
if recv_src is not None and recv_object_size_tensor is not None:
recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
if is_nccl_backend:
recv_object_tensor = recv_object_tensor.to(current_device)
filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
torch.cuda.synchronize()
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
@@ -296,112 +304,119 @@ def _send_recv_serialization_object(
if recv_object_tensor.device != torch.device("cpu"):
recv_object_tensor = recv_object_tensor.cpu()
unpickle_object = _cuda_safe_tensor_to_object(
recv_object_tensor, recv_object_size_tensor.item())
unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != torch.cuda.current_device()
):
if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
unpickle_object = unpickle_object.cuda()
return unpickle_object
def _check_if_fast_send_available(object):
if type(object) is torch.Tensor:
return True
elif type(object) is list:
is_list_of_tensor = all([type(v) is torch.Tensor for v in object])
return is_list_of_tensor
elif type(object) is dict:
is_dict_of_tensor = all([type(k) is str and type(
v) is torch.Tensor for k, v in object.items()])
return is_dict_of_tensor
return False
def _communicate(
object,
object: Any,
send_dst: Optional[int],
recv_src: Optional[int],
send_group: Optional[ProcessGroup] = None,
recv_group: Optional[ProcessGroup] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
c10d._warn_not_in_group("_communicate")
return
"""
Send an object to send_dst and receive an object from recv_src, respectively
current_send_device, is_send_nccl_backend = check_device(send_group)
current_recv_device, is_recv_nccl_backend = check_device(recv_group)
Args:
object (Any): object needed to be sent
send_dst (int): rank of the destination
recv_src (int): rank of the source
send_group (ProcessGroup, optional): process group of sender
recv_group (ProcessGroup, optional): process group of receiver
send_metadata (bool, optional): whether to send metadata
metadata_recv (P2PMetadata, optional): metadata of the object to be received
"""
assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
assert (
metadata_recv is None or len(metadata_recv.non_tensor_obj_idx) == 0
), "metadata_recv should not contain non-tensor objects"
metadata_send, tensor_objs = None, None
if object is not None:
# NOTE: if object contains non-tensor objects, we have to send metadata
metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
if send_prior_fallback:
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return _communicate(
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
)
else:
recv_data = _communicate(
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
)
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return recv_data
# NOTE: only the following 3 cases are valid:
# 1. send() [needs extra metadata] and no recv()
# 2. recv() [needs extra metadata] and no send()
# 3. neither send() nor recv() need extra metadata
assert not (send_dst is not None and send_metadata) or recv_src is None
assert not (recv_src is not None and metadata_recv is None) or send_dst is None
assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
current_send_device, is_send_nccl_backend = _check_device(send_group)
current_recv_device, is_recv_nccl_backend = _check_device(recv_group)
is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend
assert current_send_device == current_recv_device
current_device = current_send_device
assert (send_dst is not None) or (recv_src is not None)
if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
# Send and receive metadata
_metadata_recv = _send_recv_serialization_object(
object=metadata_send,
send_dst=send_dst if send_metadata else None,
recv_src=recv_src if metadata_recv is None else None,
send_group=send_group if send_metadata else None,
recv_group=recv_group if metadata_recv is None else None,
current_device=current_device,
is_nccl_backend=is_nccl_backend,
)
assert metadata_recv is None or _metadata_recv is None
metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
can_fast_send = False
send_metadata = None
if send_dst is not None:
can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
if not can_fast_send:
send_metadata = P2PMetadata(P2PDataType.serialization, object)
else:
if type(object) is torch.Tensor:
data_type = P2PDataType.tensor
content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
elif type(object) is list:
data_type = P2PDataType.list
content = []
for v in object:
content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad))
elif type(object) is dict:
data_type = P2PDataType.dict
content = []
for k, v in object.items():
content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad))
else:
raise ValueError('Cannot send object of type {}'.format(type(object)))
send_metadata = P2PMetadata(data_type, content)
# Send and receive data
recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
recv_tensor_objs = _batch_send_recv_tensor(
tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
)
recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
if recv_metadata is not None:
assert type(recv_metadata) is P2PMetadata
if recv_metadata.data_type == P2PDataType.serialization:
return recv_metadata.content
if not can_fast_send and send_dst is not None:
return
if metadata_recv is not None:
assert isinstance(metadata_recv, P2PMetadata)
tree_spec = metadata_recv.tree_spec
non_tensor_obj_idx = metadata_recv.non_tensor_obj_idx
non_tensor_objs = metadata_recv.non_tensor_objs
send_tensor_list = None
if type(object) is torch.Tensor:
send_tensor_list = object
elif type(object) is list:
send_tensor_list = object
elif type(object) is dict:
send_tensor_list = list(object.values())
if recv_tensor_objs is None:
recv_tensor_objs = []
recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)
for idx in non_tensor_obj_idx:
recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
if recv_metadata is not None:
assert recv_buffer is not None
if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
return recv_buffer
elif recv_metadata.data_type == P2PDataType.dict:
return {
k: v
for k, v in zip(
[m.key for m in recv_metadata.content],
recv_buffer,
)
}
else:
raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))
return recv_object
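To make the fallback path concrete, here is a hedged two-rank sketch (the group g and the rank numbers are assumptions): when metadata has to flow in both directions, the combined call degrades into two ordered calls, so the peers must pick opposite priorities to avoid deadlocking on each other.
# rank 0: _communicate(obj, send_dst=1, recv_src=1, send_group=g, recv_group=g,
#                      send_prior_fallback=True)   # send first, then receive
# rank 1: _communicate(obj, send_dst=0, recv_src=0, send_group=g, recv_group=g,
#                      send_prior_fallback=False)  # receive first, then send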
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
"""send anything to dst rank
Args:
@@ -411,10 +426,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
Returns:
None
"""
_communicate(object, send_dst=dst, recv_src=None, send_group=group)
_communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
"""recv anything from src
Args:
@@ -423,7 +438,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
Returns:
Any: Object received from src.
"""
return _communicate(None, send_dst=None, recv_src=src, recv_group=group)
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
def _p2p_comm(
@@ -436,7 +451,7 @@ def _p2p_comm(
"""
Send and recv tensors using P2P communication; used when the pipeline size is 2 to avoid the communication race.
Agrs:
Args:
tensor_send_next (torch.Tensor): tensor to be sent to next stage
recv_prev (bool): whether to receive tensor from previous stage
peer (int): rank of the peer
@@ -467,7 +482,6 @@ def _p2p_comm(
group=group,
)
ops.append(recv_prev_op)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
@@ -490,7 +504,6 @@ def _p2p_comm(
group=group,
)
ops.append(send_next_op)
if tensor_recv_prev is not None:
recv_prev_op = dist.P2POp(
dist.irecv,
@@ -510,7 +523,7 @@ class PipelineP2PCommunication:
def __init__(self, stage_manager: PipelineStageManager) -> None:
self.stage_manager = stage_manager
def recv_forward(self, prev_rank: int = None) -> Any:
def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args:
@@ -522,11 +535,16 @@ class PipelineP2PCommunication:
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
input_tensor = _recv_object(
prev_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
metadata_recv=metadata_recv,
)
return input_tensor
def recv_backward(self, next_rank: int = None) -> Any:
def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args:
@@ -539,12 +557,15 @@ class PipelineP2PCommunication:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(
next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank)
next_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
metadata_recv=metadata_recv,
)
return output_tensor_grad
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
"""Sends the input tensor to the next stage in pipeline.
Args:
@@ -554,9 +575,15 @@ class PipelineP2PCommunication:
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
_send_object(
output_object,
cur_rank,
next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
send_metadata=send_metadata,
)
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
Args:
@@ -566,9 +593,22 @@ class PipelineP2PCommunication:
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
_send_object(
input_object,
cur_rank,
prev_rank,
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
send_metadata=send_metadata,
)
def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
def send_forward_recv_backward(
self,
input_object: Any,
next_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
Args:
@@ -581,11 +621,24 @@ class PipelineP2PCommunication:
cur_rank = self.stage_manager.get_rank()
group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
return _communicate(
input_object, next_rank, next_rank,
send_group=group, recv_group=group,
input_object,
next_rank,
next_rank,
send_group=group,
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
)
def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
def send_backward_recv_forward(
self,
input_object: Any,
prev_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
Args:
@@ -597,37 +650,23 @@ class PipelineP2PCommunication:
cur_rank = self.stage_manager.get_rank()
group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
return _communicate(
input_object, prev_rank, prev_rank,
send_group=group, recv_group=group,
)
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the sender of the tensor
next_rank (int, optional): The rank of the recipient of the tensor
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
return _communicate(
input_object,
send_dst=next_rank,
recv_src=prev_rank,
send_group=send_group,
recv_group=recv_group,
prev_rank,
prev_rank,
send_group=group,
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
)
def p2p_communicate(
self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
self,
output_object: Any,
recv_pre: bool,
next_rank: Optional[int] = None,
comm_dtype: torch.dtype = torch.float16,
) -> None:
"""
Sends the input tensor to the next stage in the pipeline, using `P2POp` in torch.
@@ -636,10 +675,14 @@ class PipelineP2PCommunication:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if peer is None:
peer = self.stage_manager.get_next_rank()
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
recv_tensor = _p2p_comm(
output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
output_object,
recv_pre,
next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
comm_dtype,
)
return recv_tensor


@@ -1,14 +1,14 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
from torch.nn import Module
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
@@ -16,18 +16,35 @@ from .base import PipelineSchedule
class InterleavedSchedule(PipelineSchedule):
def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
self.num_model_chunks = num_model_chunks
assert (
num_microbatches % self.num_model_chunks == 0
), "Number of microbatches should be an integer multiple of number of model chunks"
def __init__(
self,
stage_manager: PipelineStageManager,
num_model_chunks: int,
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
) -> None:
super().__init__(stage_manager)
assert (
num_microbatch is not None or microbatch_size is not None
), "Either num_microbatch or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self.microbatch_size: Optional[int] = None
self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size
self.num_model_chunks = num_model_chunks
self.batch: Any
self.batch_size: int
self.last_batch_size: Optional[int] = None
self.microbatch_offset: List[int]
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -39,11 +56,37 @@ class InterleavedSchedule(PipelineSchedule):
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
if self.microbatch_size is None:
assert self.batch_size % self.num_microbatch == 0, "Batch size should be divisible by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatch
if self.num_microbatch is None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should be divisible by the microbatch size"
self.num_microbatch = self.batch_size // self.microbatch_size
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatch
assert (
self.num_microbatch % self.stage_manager.num_stages == 0
), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
if self.forward_only:
self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
if self.batch_size != self.last_batch_size:
self.enable_metadata_cache = False
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
self.last_batch_size = self.batch_size
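A hedged sketch of what the metadata cache buys over a run (the timeline below is an illustration of the flags set above, not code from this diff):
# microbatch 0: send_forward(..., send_metadata=True);  recv_forward(..., metadata_recv=None)
#               -> the peer's P2PMetadata is unpickled once and cached in tensor_metadata_recv
# microbatch 1+: send_forward(..., send_metadata=False); recv_forward(..., metadata_recv=<cache>)
#               -> only batched isend/irecv of raw tensors, no pickling round-trip
# batch size changes -> enable_metadata_cache is dropped and all four cache fields are reset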
def load_micro_batch(self, model_chunk_id: int) -> Any:
"""Load a micro batch from the current batch.
@@ -54,11 +97,12 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: Micro batch.
"""
assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number.
Args:
@@ -68,38 +112,13 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
int: The model chunk idx of the input microbatch_id
"""
microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
assert microbatch_id < self.num_microbatch * self.num_model_chunks
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not forward:
if not is_forward:
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
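A worked example of the mapping, using illustrative numbers rather than anything from this diff (4 pipeline stages and 2 model chunks, so one group spans 8 microbatches):
num_stages, num_model_chunks = 4, 2  # assumed values

def chunk_id(microbatch_id: int, is_forward: bool) -> int:
    cid = (microbatch_id % (num_stages * num_model_chunks)) // num_stages
    return cid if is_forward else num_model_chunks - cid - 1

assert [chunk_id(i, True) for i in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
assert [chunk_id(i, False) for i in range(8)] == [1, 1, 1, 1, 0, 0, 0, 0]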
def is_first_stage(self, model_chunk_id: int) -> bool:
"""Is the current virtual stage the first stage
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
bool: Whether the current virtual stage is the first stage.
"""
if self.stage_manager.is_first_stage() and model_chunk_id == 0:
return True
return False
def is_last_stage(self, model_chunk_id: int) -> bool:
"""Is the current virtual stage the last stage
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
bool: Whether the current virtual stage is the last stage.
"""
if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
return True
return False
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For interleaved 1F1B.
@@ -111,12 +130,13 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
"""
if self.is_first_stage(model_chunk_id):
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
return input_tensor
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
@@ -129,14 +149,15 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.is_last_stage(model_chunk_id):
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad
return output_tensor_grad
def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For interleaved 1F1B.
@@ -145,10 +166,12 @@ class InterleavedSchedule(PipelineSchedule):
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.is_last_stage(model_chunk_id):
self.comm.send_forward(output_object, next_rank)
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For interleaved 1F1B.
@@ -157,12 +180,102 @@ class InterleavedSchedule(PipelineSchedule):
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.is_first_stage(model_chunk_id):
self.comm.send_backward(input_object, prev_rank)
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
def send_forward_recv_backward(
self,
model_chunk_id_send: int,
model_chunk_id_recv: int,
output_tensor: Any,
next_rank: Optional[int] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
send_data = not self.stage_manager.is_last_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
recv_data = not self.stage_manager.is_last_stage()
if send_data and recv_data:
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
send_prior_fallback = None # must not fallback
output_tensor_grad = self.comm.send_forward_recv_backward(
output_tensor,
next_rank,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad
# send only or recv only
self.send_forward(model_chunk_id_send, output_tensor)
return self.recv_backward(model_chunk_id_recv)
def send_backward_recv_forward(
self,
model_chunk_id_send: int,
model_chunk_id_recv: int,
input_tensor_grad: Any,
prev_rank: Optional[int] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
send_data = not self.stage_manager.is_first_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
recv_data = not self.stage_manager.is_first_stage()
if send_data and recv_data:
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
send_prior_fallback = None # must not fallback
input_tensor = self.comm.send_backward_recv_forward(
input_tensor_grad,
prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
# send only or recv only
self.send_backward(model_chunk_id_send, input_tensor_grad)
return self.recv_forward(model_chunk_id_recv)
def send_forward_recv_forward(
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
):
if send_prior:
self.send_forward(model_chunk_id_send, output_tensor)
input_tensor = self.recv_forward(model_chunk_id_recv)
else:
input_tensor = self.recv_forward(model_chunk_id_recv)
self.send_forward(model_chunk_id_send, output_tensor)
return input_tensor
def send_backward_recv_backward(
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
):
if send_prior:
self.send_backward(model_chunk_id_send, input_tensor_grad)
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
else:
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
self.send_backward(model_chunk_id_send, input_tensor_grad)
return output_tensor_grad
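The send_prior flag is how callers break the send/recv symmetry; a hedged illustration with assumed stage indices: even stages send first while odd stages receive first, so two neighbouring stages never both block on a send.
# stage 0 (even): send_forward_recv_forward(..., send_prior=True)   # send, then receive
# stage 1 (odd):  send_forward_recv_forward(..., send_prior=False)  # receive, then send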
def forward_step(
self,
model_chunk: Module,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
input_obj: Optional[dict],
criterion: Callable,
@@ -171,7 +284,7 @@ class InterleavedSchedule(PipelineSchedule):
) -> Union[torch.Tensor, dict]:
"""Forward one step of the pipeline
Args:
model (Module): Model Chunk to be run
model (ModuleList or Module): Model Chunk to be run
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
criterion (Callable): Criterion to calculate loss.
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
@@ -184,17 +297,25 @@ class InterleavedSchedule(PipelineSchedule):
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it must be a dict
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
if self.is_last_stage(model_chunk_id):
loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss
else:
return output_obj
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
else:
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
internal_inputs = {} if input_obj is None else input_obj
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss
else:
return output_obj
def backward_step(
self,
@@ -241,19 +362,193 @@ class InterleavedSchedule(PipelineSchedule):
input_obj_grad[k] = v.grad
return input_obj_grad
def run_forward_only(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
assert self.forward_only
self.load_batch(data_iter)
outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
accum_loss = None
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
for i in range(self.num_microbatch * self.num_model_chunks):
last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not last_iteration:
input_obj = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
send_prior=self.stage_manager.stage % 2 == 0,
)
else:
self.send_forward(model_chunk_id, output_obj)
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
Runs interleaved schedule, with communication between pipeline stages.
"""
assert not self.forward_only
self.load_batch(data_iter)
num_microbatch = self.num_microbatch * self.num_model_chunks
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
num_microbatch_remaining = num_microbatch - num_warmup_microbatch
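# Worked example (illustrative numbers, not part of this diff): with 4 stages and
# 2 model chunks, stage 0 warms up (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10 microbatches,
# while the last stage warms up only (4 - 3 - 1) * 2 + (2 - 1) * 4 = 4 before 1F1B.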
# Input, output tensors only need to be saved when doing backward passes
input_objs = [[] for _ in range(self.num_model_chunks)]
output_objs = [[] for _ in range(self.num_model_chunks)]
outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None
accum_loss = None
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
# Run warmup forward passes.
for i in range(num_warmup_microbatch):
last_iteration = i == num_warmup_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
if last_iteration and num_microbatch_remaining == 0:
self.send_forward(model_chunk_id, output_obj)
else:
input_obj = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
send_prior=self.stage_manager.stage % 2 == 0,
)
if num_microbatch_remaining > 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
# Run 1F1B in steady state.
for i in range(num_microbatch_remaining):
last_iteration = i == num_microbatch_remaining - 1
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
# Pop input_obj and output_obj from the start of the list for the backward pass.
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: perform 2x communication for forward and backward
def send_forward_recv_backward():
if last_iteration and num_microbatch == num_microbatch_remaining:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
self.send_forward(model_chunk_id, output_obj)
else:
output_obj_grad = self.send_forward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
output_tensor=output_obj,
send_prior_fallback=self.stage_manager.stage % 2 == 0,
)
return output_obj_grad
def send_backward_recv_forward():
if last_iteration:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
input_tensor_grad=input_obj_grad,
send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
)
return input_obj
if self.stage_manager.stage % 2 == 0:
output_obj_grad = send_forward_recv_backward()
input_obj = send_backward_recv_forward()
else:
input_obj = send_backward_recv_forward()
output_obj_grad = send_forward_recv_backward()
if num_microbatch_remaining == 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
# Run cooldown backward passes.
for i in range(num_microbatch_remaining, num_microbatch):
last_iteration = i == num_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
# output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
if not last_iteration:
output_obj_grad = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
)
else:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model_chunk: Module,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.
"""
Args:
model_chunk (List[Module]): Model Chunk to be trained.
model_chunk (ModuleList or Module): Model chunk to be trained. The original interleaved schedule uses a ModuleList, whereas shardformer passes the entire model together with a layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
@@ -263,118 +558,15 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
assert self.forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter)
num_model_chunks = len(model_chunk)
# num_warmup_microbatches is the step when not all the processes are working
num_microbatches = self.num_microbatches * num_model_chunks
if forward_only:
num_warmup_microbatches = num_microbatches
if self.forward_only:
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
else:
num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
result = self.run_forward_backward(
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None
if not forward_only:
input_objs = [[] for _ in range(num_model_chunks)]
output_objs = [[] for _ in range(num_model_chunks)]
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None
# for ranks except the first one, get into recv state
# print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
input_obj = self.recv_forward(0)
input_objs[0].append(input_obj)
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=True)
# recv first on first rank to avoid sending or recving at the same time
if self.stage_manager.is_first_stage():
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
else:
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not forward_only:
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
break
else:
model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
input_obj = self.recv_forward(model_chunk_id)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if forward_only:
self.send_forward(model_chunk_id, output_obj)
if not last_iteration:
input_obj = self.recv_forward(model_chunk_id)
else:
self.send_forward(model_chunk_id, output_obj)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
model_chunk_id = self.get_model_chunk_id(i, forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
# backward
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_obj = None
else:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True)
input_obj = self.recv_forward(model_chunk_id)
model_chunk_id = self.get_model_chunk_id(i, forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_microbatches_remaining, num_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=False)
# print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
return result
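A hedged construction sketch based on the new signature (the stage manager and the numbers are placeholders, not values from this diff); at least one of num_microbatch or microbatch_size must be provided:
schedule = InterleavedSchedule(
    stage_manager=stage_manager,  # an existing PipelineStageManager (assumed)
    num_model_chunks=2,
    num_microbatch=8,             # alternatively pass microbatch_size=...
    enable_metadata_cache=True,
)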


@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
@@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import (
@@ -30,6 +30,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
stage_manager: PipelineStageManager,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
) -> None:
"""1F1B pipeline schedule.
@@ -42,13 +43,21 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert (
num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.last_batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self._use_microbatch_size = num_microbatches is None
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -60,24 +69,45 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.microbatch_offset = 0
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0
if not self._use_microbatch_size:
assert (
self.batch_size % self.num_microbatches == 0
), "Batch size should divided by the number of microbatches"
if self.microbatch_size is None:
assert self.batch_size % self.num_microbatches == 0, "Batch size should be divisible by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
else:
if self.num_microbatches is None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should be divisible by the microbatch size"
self.num_microbatches = self.batch_size // self.microbatch_size
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatches
assert (
self.num_microbatches >= self.stage_manager.num_stages
), "Number of microbatch should be larger than number of stages"
if self.forward_only:
self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
if self.batch_size != self.last_batch_size:
self.enable_metadata_cache = False
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
self.last_batch_size = self.batch_size
def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch.
Returns:
Any: Micro batch.
"""
assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
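Roughly what loading a micro batch amounts to (a sketch; the real get_micro_batch comes from ._utils and may differ): slice every leaf tensor of the batch along the batch dimension at the current offset.
import torch
from torch.utils._pytree import tree_map

batch = {"input_ids": torch.arange(32).reshape(8, 4), "labels": torch.arange(8)}
offset, size = 2, 2
micro_batch = tree_map(lambda t: t[offset : offset + size], batch)
assert micro_batch["input_ids"].shape == (2, 4)
assert micro_batch["labels"].shape == (2,)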
@@ -92,12 +122,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage():
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
return input_tensor
def recv_backward(self, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
@@ -109,14 +139,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad
return output_tensor_grad
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
def send_forward(self, output_tensor: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For 1F1B.
@@ -125,20 +155,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank)
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
return self.comm.send_forward_recv_backward(output_object, next_rank)
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
@@ -147,9 +167,38 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank)
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
def send_forward_recv_backward(
self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_tensor (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
send_prior_fallback = None # must not fallback
output_tensor_grad = self.comm.send_forward_recv_backward(
output_tensor,
next_rank,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad
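The send_prior_fallback argument above only matters when the fused send+receive is unavailable and the exchange has to fall back to two separate blocking calls; then one rank must send first while its neighbour receives first, or both would block. A toy illustration of that ordering (exchange, do_send and do_recv are placeholders, not the comm layer's API):
def exchange(send_first, do_send, do_recv):
    # fallback path: serialize the two blocking p2p calls in the agreed order
    if send_first:
        do_send()
        return do_recv()
    received = do_recv()
    do_send()
    return received

# neighbouring stages pick opposite orders (e.g. via stage % 2 == 0),
# so one side's send always meets the other side's receive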
def send_backward_recv_forward(
self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
@@ -158,23 +207,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_first_stage():
return self.comm.send_backward_recv_forward(output_object, prev_rank)
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
send_prior_fallback = None # must not fallback
input_tensor = self.comm.send_backward_recv_forward(
input_tensor_grad,
prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The previous rank of the recipient of the tensor.
next_rank (int, optional): The next rank of the recipient of the tensor.
"""
if self.stage_manager.is_first_stage():
return self.comm.send_forward(input_object, next_rank)
elif self.stage_manager.is_last_stage():
return self.comm.recv_forward(prev_rank)
else:
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
return input_tensor
def forward_step(
self,
@@ -254,7 +300,38 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj_grad[k] = v.grad
return input_obj_grad
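What the grad collection at the end of backward_step amounts to, in isolation (a sketch with an illustrative key name): after autograd runs, each input tensor that required grad carries .grad, and that is what gets sent back to the previous stage.
import torch

x = torch.randn(2, 3, requires_grad=True)
loss = (x * 2.0).sum()
loss.backward()
input_obj_grad = {"hidden_states": x.grad}  # key name is illustrative
assert input_obj_grad["hidden_states"].shape == (2, 3)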
def forward_backward_step(
def run_forward_only(
self,
model: Module,
data_iter: Iterable,
criterion: Callable[..., Any],
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
Runs the forward-only schedule, with communication between pipeline stages.
"""
assert self.forward_only
self.load_batch(data_iter)
accum_loss = None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
for _ in range(self.num_microbatches):
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model: Module,
data_iter: Iterable,
@@ -262,23 +339,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Args:
model (Module): Model to be trained.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
) -> Dict:
"""
forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
Runs the non-interleaved 1F1B schedule, with communication between pipeline stages.
"""
assert not self.forward_only
self.load_batch(data_iter)
@@ -288,30 +353,20 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None
input_objs, output_objs = [], []
if not forward_only:
input_objs = []
output_objs = []
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
accum_loss = None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj)
if not forward_only:
input_objs.append(input_obj)
output_objs.append(output_obj)
input_objs.append(input_obj)
output_objs.append(output_obj)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
@@ -324,44 +379,72 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if forward_only:
self.send_forward(output_obj)
output_obj_grad = self.send_forward_recv_backward(
output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
if not last_iteration:
input_obj = self.recv_forward()
else:
# TODO adjust here
self.send_forward(output_obj)
output_obj_grad = self.recv_backward()
# Pop input_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
# Pop input_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_obj = None
else:
input_obj = self.recv_forward()
if last_iteration:
self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(
input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad)
output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model: Module,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model (Module): Model to be trained.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns:
dict: Dictionary containing loss and outputs.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
if self.forward_only:
result = self.run_forward_only(model, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(model, data_iter, criterion, optimizer, return_loss, return_outputs)
return result
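A hedged usage sketch of the dispatch above: whether the backward half runs is decided purely by grad mode, so evaluation wraps the call in torch.no_grad() and may omit the optimizer (schedule, model, loader, criterion and optimizer are placeholders, not objects defined in this file).
import torch

def evaluate(schedule, model, loader, criterion):
    with torch.no_grad():  # makes the schedule take the forward-only path
        return schedule.forward_backward_step(model, iter(loader), criterion, return_loss=True)

def train_step(schedule, model, loader, criterion, optimizer):
    out = schedule.forward_backward_step(model, iter(loader), criterion, optimizer, return_loss=True)
    optimizer.step()
    optimizer.zero_grad()
    return out["loss"]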

View File

@@ -1,3 +1,4 @@
import contextlib
from typing import Dict, List, Optional, Tuple
import torch.distributed as dist
@@ -19,7 +20,15 @@ class PipelineStageManager:
stage (int): The current stage.
"""
def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None:
def __init__(
self,
pg_mesh: ProcessGroupMesh,
pipeline_axis: int,
enable_interleave: bool = False,
num_model_chunks: int = 1,
) -> None:
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -43,29 +52,56 @@ class PipelineStageManager:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
if is_virtual:
self.is_interleave = enable_interleave
if enable_interleave:
# use circular p2p communication
# add the process group of the first rank and the last rank
# only used in interleaved pipeline for now
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
if self.stage in [stages[0], stages[-1]]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
def is_first_stage(self) -> bool:
# for interleaved pipeline parallelism, each device is responsible for multiple chunks of layers
self.num_model_chunks: int = num_model_chunks
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
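Why the extra group linking the first and the last rank is added above (a sketch of the topology, not the library's code): with interleaving, a microbatch that leaves the last device's chunk re-enters the first device for the next chunk, so the pipeline behaves like a ring rather than a chain.
num_stages = 4
chain = [(s, s + 1) for s in range(num_stages - 1)]   # plain 1F1B links
ring = chain + [(num_stages - 1, 0)]                   # extra first<->last link
assert (num_stages - 1, 0) in ring and (num_stages - 1, 0) not in chain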
def is_first_stage(self, ignore_chunk: bool = False) -> bool:
"""Is the current stage the first stage.
NOTE:
1. when using interleaved pipeline parallelism, the first stage is the first chunk of the first device.
2. invoking is_first_stage() with ignore_chunk=True is equivalent to invoking is_first_device()
Returns:
bool: Whether the current stage is the first stage.
"""
return self.stage == 0
assert isinstance(ignore_chunk, bool)
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
if not self.is_interleave or ignore_chunk:
return self.stage == 0
else:
return self.stage == 0 and self.model_chunk_id == 0
def is_last_stage(self) -> bool:
def is_last_stage(self, ignore_chunk: bool = False) -> bool:
"""Is the current stage the last stage.
NOTE:
1. when using interleaved pipeline parallelism, the last stage is the last chunk of the last device.
2. invoking is_last_stage() with ignore_chunk=True is equivalent to invoking is_last_device()
Returns:
bool: Whether the current stage is the last stage.
"""
return self.stage == self.num_stages - 1
assert isinstance(ignore_chunk, bool)
assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None)
if not self.is_interleave or ignore_chunk:
return self.stage == self.num_stages - 1
else:
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
@property
def num_stages(self) -> int:
@@ -133,3 +169,10 @@ class PipelineStageManager:
ProcessGroup: Process group of the given stages.
"""
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
@contextlib.contextmanager
def switch_model_chunk_id(self, model_chunk_id: int):
old_model_chunk_id = self.model_chunk_id
self.model_chunk_id = model_chunk_id
yield
self.model_chunk_id = old_model_chunk_id
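A hedged usage sketch of the chunk-aware checks and the context manager above (stage_manager and run_chunk are placeholders): an interleaved schedule switches the active chunk per step, so is_first_stage()/is_last_stage() answer for that chunk unless ignore_chunk=True, which asks about the device instead.
def forward_one_chunk(stage_manager, chunk_id, micro_batch, run_chunk):
    with stage_manager.switch_model_chunk_id(chunk_id):
        if stage_manager.is_first_stage():
            # only the first chunk on the first device reads from the data loader
            inputs = micro_batch
        else:
            inputs = None  # would instead be received from the previous stage
        output = run_chunk(chunk_id, inputs)
        if stage_manager.is_last_stage(ignore_chunk=True):
            pass  # true on the last device for every chunk, e.g. for logging
    return output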