Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)
[Feature] optimize PP overlap (#5735)
* update to fully overlap, still debugging
* improve interface
* fixed deadlock bug
* debug NaN loss
* (experimental) use one comm group for send_fw_recv_fw to fix NaN
* cleaned up interfaces; use one batch p2p for all
* clean up; removed the double p2p batch case
* p2p test passed
* improve overlap: send fwd before backward
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* tentatively use 2 p2p batches
* remove two p2p batches
* fix typos
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* remove pp.sh

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@notebook-c55824c0-7742-45e8-9591-c855bb77ad29-0.notebook-c55824c0-7742-45e8-9591-c855bb77ad29.colossal-ai.svc.cluster.local>
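The heart of the optimization is to stop waiting on P2P work inside the communication helpers and instead hand the Work handles back to the pipeline schedule. A minimal sketch of that deferred-wait pattern, with illustrative names rather than the actual helpers in colossalai/pipeline/p2p.py:

import torch.distributed as dist


def overlapped_send_recv(send_tensor, recv_buffer, send_dst, recv_src, group):
    # Queue both directions into a single batched P2P launch.
    ops = []
    if send_tensor is not None and send_dst is not None:
        ops.append(dist.P2POp(dist.isend, send_tensor, send_dst, group))
    if recv_buffer is not None and recv_src is not None:
        ops.append(dist.P2POp(dist.irecv, recv_buffer, recv_src, group))
    handles = dist.batch_isend_irecv(ops) if ops else []
    # Do not wait here; return the Work handles so the schedule can
    # overlap computation with the in-flight communication.
    return recv_buffer, handles


# Caller-side pattern (sketch):
#   buf, handles = overlapped_send_recv(out, buf, next_rank, prev_rank, group)
#   ... independent compute runs while the P2P is in flight ...
#   for h in handles:
#       h.wait()  # block only when buf is actually needed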
@@ -225,31 +225,41 @@ def _batch_send_recv_tensor(
    send_group: Optional[ProcessGroup],
    recv_group: Optional[ProcessGroup],
    current_device: Any,
    overlap_p2p: bool = True,
    send_first: bool = True,
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
    buffer_recv = None
    if recv_tensor_metadata is not None:
        buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)

    ops = []
    if send_dst is not None and send_tensor_list is not None:
        assert send_group is not None
        _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
    if recv_src is not None and buffer_recv is not None:
        assert recv_group is not None
        _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
    is_send = send_dst is not None and send_tensor_list is not None
    is_recv = recv_src is not None and buffer_recv is not None

    if send_first:
        if is_send:
            assert send_group is not None
            _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
        if is_recv:
            assert recv_group is not None
            _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
    else:
        if is_recv:
            assert recv_group is not None
            _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
        if is_send:
            assert send_group is not None
            _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

        # Remove synchronization according to Pytorch's documentation
        # However, the Megatron-LM does synchronization here
        # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
        # In case there is potential error, uncomment the following `torch.cuda.synchronize()`
        # torch.cuda.synchronize()

        return buffer_recv
        if not overlap_p2p:
            for req in reqs:
                req.wait()
            return buffer_recv, []
        else:
            return buffer_recv, reqs
    return None, []


def _send_recv_serialization_object(
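The new send_first flag only controls the order in which the isend/irecv ops are appended to the batch. For the batched launch to match up across a pipeline boundary, the two neighbouring ranks should enqueue in opposite orders; the policy below is a hypothetical illustration, not part of this file (the actual choice is made by the schedules that call into p2p.py):

def pick_send_first(stage_index: int) -> bool:
    # Hypothetical policy: even stages post their sends first, odd stages
    # post their receives first, so each rank's isend is paired with the
    # neighbouring rank's irecv inside the same batched launch.
    return stage_index % 2 == 0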
@@ -260,10 +270,11 @@ def _send_recv_serialization_object(
    recv_group: Optional[ProcessGroup],
    current_device: Any,
    is_nccl_backend: bool,
    send_first: bool = True,
) -> Optional[P2PMetadata]:
    ops = []

    send_object_tensor = None
    send_object_size_tensor = None
    if object is not None and send_dst is not None:
        if Version(torch.__version__) >= Version("1.13.0"):
            send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
@@ -274,43 +285,54 @@ def _send_recv_serialization_object(
            send_object_size_tensor = send_object_size_tensor.to(current_device)
            send_object_tensor = send_object_tensor.to(current_device)

        _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)

    recv_object_size_tensor = None
    if recv_src is not None:
        recv_object_size_tensor = torch.empty(1, dtype=torch.long)
        if is_nccl_backend:
            recv_object_size_tensor = recv_object_size_tensor.to(current_device)
        _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)

    if send_first:
        if send_object_size_tensor is not None:
            _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
        if recv_src is not None:
            _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
    else:
        if recv_src is not None:
            _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
        if send_object_size_tensor is not None:
            _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

        # See the comment in `_batch_send_recv_tensor`
        # torch.cuda.synchronize()
            req.wait()  # This blocks the compute stream in torch

    ops = []

    if send_dst is not None and send_object_tensor is not None:
        _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
    is_send = send_dst is not None and send_object_tensor is not None
    is_recv = recv_src is not None and recv_object_size_tensor is not None

    recv_object_tensor = None
    if recv_src is not None and recv_object_size_tensor is not None:
    if is_recv:
        recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
        if is_nccl_backend:
            recv_object_tensor = recv_object_tensor.to(current_device)
        _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)

    if send_first:
        if is_send:
            _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
        if is_recv:
            _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
    else:
        if is_recv:
            _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
        if is_send:
            _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

        # See the comment in `_batch_send_recv_tensor`
        # torch.cuda.synchronize()

    if recv_object_tensor is not None and recv_object_size_tensor is not None:
        recv_object_tensor = recv_object_tensor.type(torch.uint8)
        if recv_object_tensor.device != torch.device("cpu"):
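_send_recv_serialization_object exchanges a pickled object in two rounds: a one-element long tensor carrying the payload size, then a uint8 tensor holding the pickled bytes. The sketch below shows the same two-phase idea with blocking send/recv and the public pickle API instead of the batched ops and private c10d helpers used above; it is a simplified stand-in, not the library code:

import pickle

import torch
import torch.distributed as dist


def send_object_two_phase(obj, dst, group):
    # Blocking CPU-tensor variant; the real code moves tensors to the NCCL
    # device and batches the ops instead.
    payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
    size = torch.tensor([payload.numel()], dtype=torch.long)
    dist.send(size, dst, group=group)      # phase 1: tell the peer how much to allocate
    dist.send(payload, dst, group=group)   # phase 2: ship the pickled bytes


def recv_object_two_phase(src, group):
    size = torch.empty(1, dtype=torch.long)
    dist.recv(size, src, group=group)
    payload = torch.empty(size.item(), dtype=torch.uint8)
    dist.recv(payload, src, group=group)
    return pickle.loads(payload.numpy().tobytes())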
@@ -328,11 +350,12 @@ def _communicate(
    object: Any,
    send_dst: Optional[int],
    recv_src: Optional[int],
    overlap_p2p: bool,
    send_group: Optional[ProcessGroup] = None,
    recv_group: Optional[ProcessGroup] = None,
    send_metadata: bool = True,
    metadata_recv: Optional[P2PMetadata] = None,
    send_prior_fallback: Optional[bool] = None,
    send_first: Optional[bool] = None,
) -> Any:
    """
    Send and receive object from send_dst and recv_src respectively
@@ -341,6 +364,7 @@ def _communicate(
        object (Any): object needed to be sent
        send_dst (int): rank of the destination
        recv_src (int): rank of the source
        overlap_p2p (bool): whether to overlap p2p communication with computation
        send_group (ProcessGroup, optional): process group of sender
        recv_group (ProcessGroup, optional): process group of receiver
        send_metadata (bool, optional): whether to send metadata
@@ -358,32 +382,10 @@ def _communicate(
        # NOTE: if object contains non-tensor objects, we have to send metadata
        metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
        send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
    else:
        send_metadata = False

    # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
    # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
    if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
        assert send_prior_fallback is not None, "Priority must be set if fallback happens"
        if send_prior_fallback:
            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
            return _communicate(
                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
            )
        else:
            recv_data = _communicate(
                None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
            )
            _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
            return recv_data

    # NOTE: only the following 5 cases are valid:
    # 1. send() [needs extra metadata] and no recv()
    # 2. recv() [needs extra metadata] and no send()
    # 3. neither send() nor recv() need extra metadata
    assert not (send_dst is not None and send_metadata) or recv_src is None
    assert not (recv_src is not None and metadata_recv is None) or send_dst is None
    assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
    assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)

    current_send_device, is_send_nccl_backend = _check_device(send_group)
    current_recv_device, is_recv_nccl_backend = _check_device(recv_group)

@@ -402,14 +404,25 @@ def _communicate(
        recv_group=recv_group if metadata_recv is None else None,
        current_device=current_device,
        is_nccl_backend=is_nccl_backend,
        send_first=send_first if send_first != None else True,
    )
    assert metadata_recv is None or _metadata_recv is None
    assert (
        metadata_recv is None or _metadata_recv is None
    ), "You shouldn't receive metadata when using the cached metadata"
    metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv

    # Send and receive data
    recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
    recv_tensor_objs = _batch_send_recv_tensor(
        tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
    recv_tensor_objs, wait_handles = _batch_send_recv_tensor(
        tensor_objs,
        recv_tensor_metadata,
        send_dst,
        recv_src,
        send_group,
        recv_group,
        current_device,
        overlap_p2p=overlap_p2p,
        send_first=send_first if send_first != None else True,
    )

    if metadata_recv is not None:
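_communicate now returns the received object together with the outstanding wait handles. The schedule-side consumption is not part of this diff; a minimal sketch of what waiting on those handles looks like:

from typing import List, Optional


def wait_p2p(wait_handles: Optional[List]) -> None:
    # Block only when the received buffers are about to be used; each handle
    # is a torch.distributed Work object returned by dist.batch_isend_irecv.
    if wait_handles is not None:
        for req in wait_handles:
            req.wait()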
@@ -424,33 +437,9 @@ def _communicate(
        for idx in non_tensor_obj_idx:
            recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
        recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
        return recv_object, wait_handles

        return recv_object


def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
    """send anything to dst rank

    Args:
        object (Any): object needed to be sent
        dst (int): rank of the destination

    Returns:
        None
    """
    _communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)


def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
    """recv anything from src

    Args:
        src (int): source rank of data. local rank will receive data from src rank.

    Returns:
        Any: Object received from src.
    """
    return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
    return None, wait_handles


def _p2p_comm(
@@ -532,10 +521,13 @@ def _p2p_comm(


class PipelineP2PCommunication:
    def __init__(self, stage_manager: PipelineStageManager) -> None:
    def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None:
        self.stage_manager = stage_manager
        self.overlap_p2p = overlap_p2p

    def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
    def recv_forward(
        self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
    ) -> Tuple[Any, List]:
        """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

        Args:
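A hedged usage sketch of the updated interface (assumes a running pipeline job with an initialized PipelineStageManager; the driver function and its names are illustrative, not part of this diff):

def exchange_one_microbatch(comm: PipelineP2PCommunication, output_object):
    # With overlap_p2p=True both calls return immediately, handing back
    # the Work handles of the posted isend/irecv ops.
    input_tensor, recv_handles = comm.recv_forward()
    send_handles = comm.send_forward(output_object)

    # ... other microbatch work can run here while both P2Ps are in flight ...

    for handle in recv_handles + send_handles:
        handle.wait()  # synchronize only when the buffers must be consumed or reused
    return input_tensor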
@@ -543,95 +535,186 @@ class PipelineP2PCommunication:

        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        if prev_rank is None:
            prev_rank = self.stage_manager.get_prev_rank()
        cur_rank = self.stage_manager.get_rank()
        input_tensor = _recv_object(
            prev_rank,
            cur_rank,
            self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
        input_tensor, wait_handles = _communicate(
            object=None,
            recv_src=prev_rank,
            send_dst=None,
            recv_group=self.stage_manager.get_p2p_process_group(),
            metadata_recv=metadata_recv,
            overlap_p2p=self.overlap_p2p,
        )

        return input_tensor
        return input_tensor, wait_handles

    def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
    def recv_backward(
        self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
    ) -> Tuple[Any, List]:
        """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

        Args:
            next_rank (int, optional): The rank of the source of the tensor.

        Returns:
            Any: The input gradient tensor or gradient tensor list.
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()
        cur_rank = self.stage_manager.get_rank()
        output_tensor_grad = _recv_object(
            next_rank,
            cur_rank,
            self.stage_manager.get_p2p_process_group(next_rank, cur_rank),

        output_tensor_grad, wait_handles = _communicate(
            object=None,
            recv_src=next_rank,
            send_dst=None,
            recv_group=self.stage_manager.get_p2p_process_group(),
            metadata_recv=metadata_recv,
            overlap_p2p=self.overlap_p2p,
        )

        return output_tensor_grad
        return output_tensor_grad, wait_handles

    def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
    def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List:
        """Sends the input tensor to the next stage in pipeline.

        Args:
            output_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the recipient of the tensor.

        Returns:
            List: List of handles for the communication requests, if overlap is enabled.
        """
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()
        cur_rank = self.stage_manager.get_rank()
        _send_object(
        _, handles = _communicate(
            output_object,
            cur_rank,
            next_rank,
            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
            recv_src=None,
            send_dst=next_rank,
            send_group=self.stage_manager.get_p2p_process_group(),
            send_metadata=send_metadata,
            overlap_p2p=self.overlap_p2p,
        )
        return handles

    def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
    def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List:
        """Sends the gradient tensor to the previous stage in pipeline.

        Args:
            input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the recipient of the tensor

        Returns:
            List: List of handles for the communication requests, if overlap is enabled.
        """
        if prev_rank is None:
            prev_rank = self.stage_manager.get_prev_rank()
        cur_rank = self.stage_manager.get_rank()
        _send_object(
        _, handles = _communicate(
            input_object,
            cur_rank,
            prev_rank,
            self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
            recv_src=None,
            send_dst=prev_rank,
            send_group=self.stage_manager.get_p2p_process_group(),
            send_metadata=send_metadata,
            overlap_p2p=self.overlap_p2p,
        )
        return handles

    def send_forward_recv_forward(
        self,
        output_object: Any,
        is_send: bool,
        is_recv: bool,
        send_first: bool,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
    ) -> Tuple[Any, List]:
        """Sends the input tensor to the next pipeline stage and copy the output tensor from the next pipeline stage

        Args:
            output_object (Any): Object to be sent.
            is_send (bool): Whether to send the input tensor to the next pipeline stage.
            is_recv (bool): Whether to copy the output tensor from the next pipeline stage.
            send_first (bool): Whether to send before receive.
            send_metadata (bool, optional): Whether to send metadata.
            metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.

        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        next_rank = self.stage_manager.get_next_rank() if is_send else None
        prev_rank = self.stage_manager.get_prev_rank() if is_recv else None
        group = self.stage_manager.get_p2p_process_group()
        return _communicate(
            output_object,
            send_dst=next_rank,
            recv_src=prev_rank,
            send_group=group if is_send else None,
            recv_group=group if is_recv else None,
            send_metadata=send_metadata if is_send else False,
            metadata_recv=metadata_recv if is_recv else None,
            send_first=send_first,
            overlap_p2p=self.overlap_p2p,
        )

    def send_backward_recv_backward(
        self,
        input_object: Any,
        is_send: bool,
        is_recv: bool,
        send_first: bool,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
    ) -> Tuple[Any, List]:
        """Sends the gradient tensor to the previous pipeline stage and copy the gradient tensor from the previous pipeline stage

        Args:
            input_object (Any): Object to be sent.
            is_send (bool): Whether to send the gradient tensor to the previous pipeline stage.
            is_recv (bool): Whether to copy the gradient tensor from the previous pipeline stage.
            send_first (bool): Whether to send before receive.
            send_metadata (bool, optional): Whether to send metadata.
            metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.

        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        prev_rank = self.stage_manager.get_prev_rank() if is_send else None
        next_rank = self.stage_manager.get_next_rank() if is_recv else None

        group = self.stage_manager.get_p2p_process_group()

        return _communicate(
            input_object,
            send_dst=prev_rank,
            recv_src=next_rank,
            send_group=group if is_send else None,
            recv_group=group if is_recv else None,
            send_metadata=send_metadata if is_send else False,
            metadata_recv=metadata_recv if is_recv else None,
            send_first=send_first,
            overlap_p2p=self.overlap_p2p,
        )

    def send_forward_recv_backward(
        self,
        input_object: Any,
        next_rank: Optional[int] = None,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
        send_prior_fallback: Optional[bool] = None,
    ) -> Any:
        """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
        send_first: Optional[bool] = None,
    ) -> Tuple[Any, List]:
        """Sends the gradient tensor to and copy the gradient tensor from the next pipeline stage

        Args:
            input_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the sender and recipient of the tensor
        """
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()

        cur_rank = self.stage_manager.get_rank()
        group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        next_rank = self.stage_manager.get_next_rank()
        group = self.stage_manager.get_p2p_process_group()
        return _communicate(
            input_object,
            next_rank,
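The fused send_forward_recv_forward and send_backward_recv_backward calls let a schedule post both directions in one batched launch. A hypothetical wrapper showing the intended call shape (the wrapper itself is not part of this diff):

def fused_forward_exchange(comm, output_obj, send_first: bool, meta_cache=None):
    # Send this stage's output downstream and receive the next input from
    # upstream in a single batched P2P; ordering across neighbouring ranks
    # is coordinated through send_first.
    input_obj, handles = comm.send_forward_recv_forward(
        output_obj,
        is_send=True,
        is_recv=True,
        send_first=send_first,
        metadata_recv=meta_cache,
    )
    return input_obj, handles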
@@ -640,28 +723,28 @@ class PipelineP2PCommunication:
            recv_group=group,
            send_metadata=send_metadata,
            metadata_recv=metadata_recv,
            send_prior_fallback=send_prior_fallback,
            send_first=send_first,
            overlap_p2p=False,
        )

    def send_backward_recv_forward(
        self,
        input_object: Any,
        prev_rank: Optional[int] = None,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
        send_prior_fallback: Optional[bool] = None,
    ) -> Any:
        send_first: Optional[bool] = None,
    ) -> Tuple[Any, List]:
        """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline

        Args:
            input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the sender and recipient of the tensor
        """
        if prev_rank is None:
            prev_rank = self.stage_manager.get_prev_rank()

        cur_rank = self.stage_manager.get_rank()
        group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        prev_rank = self.stage_manager.get_prev_rank()
        group = self.stage_manager.get_p2p_process_group()
        return _communicate(
            input_object,
            prev_rank,
@@ -670,7 +753,8 @@ class PipelineP2PCommunication:
            recv_group=group,
            send_metadata=send_metadata,
            metadata_recv=metadata_recv,
            send_prior_fallback=send_prior_fallback,
            send_first=send_first,
            overlap_p2p=False,
        )

    def p2p_communicate(
@@ -679,7 +763,7 @@ class PipelineP2PCommunication:
        recv_pre: bool,
        next_rank: Optional[int] = None,
        comm_dtype: torch.dtype = torch.float16,
    ) -> None:
    ) -> Any:
        """
        Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.

@@ -689,12 +773,11 @@ class PipelineP2PCommunication:
        """
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()
        cur_rank = self.stage_manager.get_rank()
        recv_tensor = _p2p_comm(
            output_object,
            recv_pre,
            next_rank,
            self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
            self.stage_manager.get_p2p_process_group(),
            comm_dtype,
        )
        return recv_tensor