[Feature] optimize PP overlap (#5735)

* update to fully overlap, still debugging

* improve interface

* fixed deadlock bug

* debug NaN loss

* (experimental) use one comm group for send_fw_recv_fw to fix NaN

* cleaned up interfaces; use one batch p2p for all

* clean up; removed the double p2p batch case

* p2p test passsed

* improve overlap: send fwd before backward

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tentatively use 2 p2p batches

* remove two p2p batches

* fix typos

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove pp.sh

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@notebook-c55824c0-7742-45e8-9591-c855bb77ad29-0.notebook-c55824c0-7742-45e8-9591-c855bb77ad29.colossal-ai.svc.cluster.local>
This commit is contained in:
Edenzzzz
2024-06-26 14:48:02 +08:00
committed by GitHub
parent 4ccaaaab63
commit 2a25a2aff7
9 changed files with 457 additions and 358 deletions

View File

@@ -225,31 +225,41 @@ def _batch_send_recv_tensor(
send_group: Optional[ProcessGroup],
recv_group: Optional[ProcessGroup],
current_device: Any,
overlap_p2p: bool = True,
send_first: bool = True,
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
buffer_recv = None
if recv_tensor_metadata is not None:
buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)
ops = []
if send_dst is not None and send_tensor_list is not None:
assert send_group is not None
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if recv_src is not None and buffer_recv is not None:
assert recv_group is not None
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
is_send = send_dst is not None and send_tensor_list is not None
is_recv = recv_src is not None and buffer_recv is not None
if send_first:
if is_send:
assert send_group is not None
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if is_recv:
assert recv_group is not None
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
else:
if is_recv:
assert recv_group is not None
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
if is_send:
assert send_group is not None
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# Remove synchronization according to Pytorch's documentation
# However, the Megatron-LM does synchronization here
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
# torch.cuda.synchronize()
return buffer_recv
if not overlap_p2p:
for req in reqs:
req.wait()
return buffer_recv, []
else:
return buffer_recv, reqs
return None, []
def _send_recv_serialization_object(
@@ -260,10 +270,11 @@ def _send_recv_serialization_object(
recv_group: Optional[ProcessGroup],
current_device: Any,
is_nccl_backend: bool,
send_first: bool = True,
) -> Optional[P2PMetadata]:
ops = []
send_object_tensor = None
send_object_size_tensor = None
if object is not None and send_dst is not None:
if Version(torch.__version__) >= Version("1.13.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
@@ -274,43 +285,54 @@ def _send_recv_serialization_object(
send_object_size_tensor = send_object_size_tensor.to(current_device)
send_object_tensor = send_object_tensor.to(current_device)
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
recv_object_size_tensor = None
if recv_src is not None:
recv_object_size_tensor = torch.empty(1, dtype=torch.long)
if is_nccl_backend:
recv_object_size_tensor = recv_object_size_tensor.to(current_device)
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
if send_first:
if send_object_size_tensor is not None:
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
if recv_src is not None:
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
else:
if recv_src is not None:
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
if send_object_size_tensor is not None:
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
req.wait() # This blocks the compute stream in torch
ops = []
if send_dst is not None and send_object_tensor is not None:
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
is_send = send_dst is not None and send_object_tensor is not None
is_recv = recv_src is not None and recv_object_size_tensor is not None
recv_object_tensor = None
if recv_src is not None and recv_object_size_tensor is not None:
if is_recv:
recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
if is_nccl_backend:
recv_object_tensor = recv_object_tensor.to(current_device)
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
if send_first:
if is_send:
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
if is_recv:
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
else:
if is_recv:
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
if is_send:
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
if recv_object_tensor is not None and recv_object_size_tensor is not None:
recv_object_tensor = recv_object_tensor.type(torch.uint8)
if recv_object_tensor.device != torch.device("cpu"):
@@ -328,11 +350,12 @@ def _communicate(
object: Any,
send_dst: Optional[int],
recv_src: Optional[int],
overlap_p2p: bool,
send_group: Optional[ProcessGroup] = None,
recv_group: Optional[ProcessGroup] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
send_first: Optional[bool] = None,
) -> Any:
"""
Send and receive object from send_dst and recv_src respectively
@@ -341,6 +364,7 @@ def _communicate(
object (Any): object needed to be sent
send_dst (int): rank of the destination
recv_src (int): rank of the source
overlap_p2p (bool): whether to overlap p2p communication with computation
send_group (ProcessGroup, optional): process group of sender
recv_group (ProcessGroup, optional): process group of receiver
send_metadata (bool, optional): whether to send metadata
@@ -358,32 +382,10 @@ def _communicate(
# NOTE: if object contains non-tensor objects, we have to send metadata
metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
else:
send_metadata = False
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
if send_prior_fallback:
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return _communicate(
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
)
else:
recv_data = _communicate(
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
)
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return recv_data
# NOTE: only the following 5 cases are valid:
# 1. send() [needs extra metadata] and no recv()
# 2. recv() [needs extra metadata] and no send()
# 3. neither send() nor recv() need extra metadata
assert not (send_dst is not None and send_metadata) or recv_src is None
assert not (recv_src is not None and metadata_recv is None) or send_dst is None
assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
current_send_device, is_send_nccl_backend = _check_device(send_group)
current_recv_device, is_recv_nccl_backend = _check_device(recv_group)
@@ -402,14 +404,25 @@ def _communicate(
recv_group=recv_group if metadata_recv is None else None,
current_device=current_device,
is_nccl_backend=is_nccl_backend,
send_first=send_first if send_first != None else True,
)
assert metadata_recv is None or _metadata_recv is None
assert (
metadata_recv is None or _metadata_recv is None
), "You shouldn't receive metadata when using the cached metadata"
metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
# Send and receive data
recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
recv_tensor_objs = _batch_send_recv_tensor(
tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
recv_tensor_objs, wait_handles = _batch_send_recv_tensor(
tensor_objs,
recv_tensor_metadata,
send_dst,
recv_src,
send_group,
recv_group,
current_device,
overlap_p2p=overlap_p2p,
send_first=send_first if send_first != None else True,
)
if metadata_recv is not None:
@@ -424,33 +437,9 @@ def _communicate(
for idx in non_tensor_obj_idx:
recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
return recv_object, wait_handles
return recv_object
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
"""send anything to dst rank
Args:
object (Any): object needed to be sent
dst (int): rank of the destination
Returns:
None
"""
_communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
"""recv anything from src
Args:
src (int): source rank of data. local rank will receive data from src rank.
Returns:
Any: Object received from src.
"""
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
return None, wait_handles
def _p2p_comm(
@@ -532,10 +521,13 @@ def _p2p_comm(
class PipelineP2PCommunication:
def __init__(self, stage_manager: PipelineStageManager) -> None:
def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None:
self.stage_manager = stage_manager
self.overlap_p2p = overlap_p2p
def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
def recv_forward(
self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args:
@@ -543,95 +535,186 @@ class PipelineP2PCommunication:
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(
prev_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
input_tensor, wait_handles = _communicate(
object=None,
recv_src=prev_rank,
send_dst=None,
recv_group=self.stage_manager.get_p2p_process_group(),
metadata_recv=metadata_recv,
overlap_p2p=self.overlap_p2p,
)
return input_tensor
return input_tensor, wait_handles
def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
def recv_backward(
self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
) -> Tuple[Any, List]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args:
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(
next_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
output_tensor_grad, wait_handles = _communicate(
object=None,
recv_src=next_rank,
send_dst=None,
recv_group=self.stage_manager.get_p2p_process_group(),
metadata_recv=metadata_recv,
overlap_p2p=self.overlap_p2p,
)
return output_tensor_grad
return output_tensor_grad, wait_handles
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List:
"""Sends the input tensor to the next stage in pipeline.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
Returns:
List: List of handles for the communication requests, if overlap is enabled.
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(
_, handles = _communicate(
output_object,
cur_rank,
next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
recv_src=None,
send_dst=next_rank,
send_group=self.stage_manager.get_p2p_process_group(),
send_metadata=send_metadata,
overlap_p2p=self.overlap_p2p,
)
return handles
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
Returns:
List: List of handles for the communication requests, if overlap is enabled.
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(
_, handles = _communicate(
input_object,
cur_rank,
prev_rank,
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
recv_src=None,
send_dst=prev_rank,
send_group=self.stage_manager.get_p2p_process_group(),
send_metadata=send_metadata,
overlap_p2p=self.overlap_p2p,
)
return handles
def send_forward_recv_forward(
self,
output_object: Any,
is_send: bool,
is_recv: bool,
send_first: bool,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
) -> Tuple[Any, List]:
"""Sends the input tensor to the next pipeline stage and copy the output tensor from the next pipeline stage
Args:
output_object (Any): Object to be sent.
is_send (bool): Whether to send the input tensor to the next pipeline stage.
is_recv (bool): Whether to copy the output tensor from the next pipeline stage.
send_first (bool): Whether to send before receive.
send_metadata (bool, optional): Whether to send metadata.
metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
next_rank = self.stage_manager.get_next_rank() if is_send else None
prev_rank = self.stage_manager.get_prev_rank() if is_recv else None
group = self.stage_manager.get_p2p_process_group()
return _communicate(
output_object,
send_dst=next_rank,
recv_src=prev_rank,
send_group=group if is_send else None,
recv_group=group if is_recv else None,
send_metadata=send_metadata if is_send else False,
metadata_recv=metadata_recv if is_recv else None,
send_first=send_first,
overlap_p2p=self.overlap_p2p,
)
def send_backward_recv_backward(
self,
input_object: Any,
is_send: bool,
is_recv: bool,
send_first: bool,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
) -> Tuple[Any, List]:
"""Sends the gradient tensor to the previous pipeline stage and copy the gradient tensor from the previous pipeline stage
Args:
input_object (Any): Object to be sent.
is_send (bool): Whether to send the gradient tensor to the previous pipeline stage.
is_recv (bool): Whether to copy the gradient tensor from the previous pipeline stage.
send_first (bool): Whether to send before receive.
send_metadata (bool, optional): Whether to send metadata.
metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
prev_rank = self.stage_manager.get_prev_rank() if is_send else None
next_rank = self.stage_manager.get_next_rank() if is_recv else None
group = self.stage_manager.get_p2p_process_group()
return _communicate(
input_object,
send_dst=prev_rank,
recv_src=next_rank,
send_group=group if is_send else None,
recv_group=group if is_recv else None,
send_metadata=send_metadata if is_send else False,
metadata_recv=metadata_recv if is_recv else None,
send_first=send_first,
overlap_p2p=self.overlap_p2p,
)
def send_forward_recv_backward(
self,
input_object: Any,
next_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
send_first: Optional[bool] = None,
) -> Tuple[Any, List]:
"""Sends the gradient tensor to and copy the gradient tensor from the next pipeline stage
Args:
input_object (Any): Object to be sent.
next_rank (int, optional): The rank of the sender and recipient of the tensor
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
next_rank = self.stage_manager.get_next_rank()
group = self.stage_manager.get_p2p_process_group()
return _communicate(
input_object,
next_rank,
@@ -640,28 +723,28 @@ class PipelineP2PCommunication:
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
overlap_p2p=False,
)
def send_backward_recv_forward(
self,
input_object: Any,
prev_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
send_first: Optional[bool] = None,
) -> Tuple[Any, List]:
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the sender and recipient of the tensor
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
prev_rank = self.stage_manager.get_prev_rank()
group = self.stage_manager.get_p2p_process_group()
return _communicate(
input_object,
prev_rank,
@@ -670,7 +753,8 @@ class PipelineP2PCommunication:
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
overlap_p2p=False,
)
def p2p_communicate(
@@ -679,7 +763,7 @@ class PipelineP2PCommunication:
recv_pre: bool,
next_rank: Optional[int] = None,
comm_dtype: torch.dtype = torch.float16,
) -> None:
) -> Any:
"""
Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
@@ -689,12 +773,11 @@ class PipelineP2PCommunication:
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
recv_tensor = _p2p_comm(
output_object,
recv_pre,
next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
self.stage_manager.get_p2p_process_group(),
comm_dtype,
)
return recv_tensor