[pipeline]: add p2p fallback order and fix interleaved pp deadlock (#5214)

* fix: add fallback order option and update 1f1b

* fix: fix deadlock comm in interleaved pp

* test: modify p2p test
This commit is contained in:
Wenhao Chen
2024-01-03 11:34:49 +08:00
committed by GitHub
parent 3c0d82b19b
commit d799a3088f
5 changed files with 269 additions and 136 deletions

View File

@@ -344,6 +344,7 @@ def _communicate(
recv_group: Optional[ProcessGroup] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""
Send and receive object from send_dst and recv_src respectively
@@ -368,8 +369,14 @@ def _communicate(
# NOTE: send & recv should ideally be a single atomic operation. However, if we need to send or receive metadata,
# that is not possible, because two rounds would be required (1. send & recv metadata, 2. send & recv the object). In this case we split the operation into a separate send and a separate recv.
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
if send_prior_fallback:
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
else:
recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return recv_data
# NOTE: only the following 5 cases are valid:
# 1. send() [needs extra metadata] and no recv()
@@ -437,7 +444,7 @@ def _communicate(
raise ValueError("Unknown data type {}".format(metadata_recv.data_type))
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None:
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
"""send anything to dst rank
Args:
@@ -447,10 +454,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_meta
Returns:
None
"""
_communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata)
_communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any:
def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
"""recv anything from src
Args:
@@ -459,7 +466,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optiona
Returns:
Any: Object received from src.
"""
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv)
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
def _p2p_comm(
@@ -557,7 +564,10 @@ class PipelineP2PCommunication:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(
prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv
prev_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
metadata_recv=metadata_recv,
)
return input_tensor
@@ -575,7 +585,10 @@ class PipelineP2PCommunication:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(
next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv
next_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
metadata_recv=metadata_recv,
)
return output_tensor_grad
@@ -595,7 +608,7 @@ class PipelineP2PCommunication:
cur_rank,
next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
send_metadata,
send_metadata=send_metadata,
)
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
@@ -613,7 +626,7 @@ class PipelineP2PCommunication:
cur_rank,
prev_rank,
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
send_metadata,
send_metadata=send_metadata,
)
def send_forward_recv_backward(
@@ -622,6 +635,7 @@ class PipelineP2PCommunication:
next_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
@@ -642,6 +656,7 @@ class PipelineP2PCommunication:
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
)
def send_backward_recv_forward(
@@ -650,6 +665,7 @@ class PipelineP2PCommunication:
prev_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
@@ -670,6 +686,7 @@ class PipelineP2PCommunication:
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
)
def p2p_communicate(