mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[pipeline]: add p2p fallback order and fix interleaved pp deadlock (#5214)
* fix: add fallback order option and update 1f1b * fix: fix deadlock comm in interleaved pp * test: modify p2p test
This commit is contained in:
@@ -344,6 +344,7 @@ def _communicate(
|
||||
recv_group: Optional[ProcessGroup] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Send and receive object from send_dst and recv_src respectively
|
||||
@@ -368,8 +369,14 @@ def _communicate(
|
||||
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
|
||||
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
|
||||
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
|
||||
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
|
||||
if send_prior_fallback:
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
|
||||
else:
|
||||
recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return recv_data
|
||||
|
||||
# NOTE: only the following 5 cases are valid:
|
||||
# 1. send() [needs extra metadata] and no recv()
|
||||
@@ -437,7 +444,7 @@ def _communicate(
|
||||
raise ValueError("Unknown data type {}".format(metadata_recv.data_type))
|
||||
|
||||
|
||||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None:
|
||||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
|
||||
"""send anything to dst rank
|
||||
|
||||
Args:
|
||||
@@ -447,10 +454,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_meta
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
_communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata)
|
||||
_communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
|
||||
|
||||
|
||||
def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any:
|
||||
def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
|
||||
"""recv anything from src
|
||||
|
||||
Args:
|
||||
@@ -459,7 +466,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optiona
|
||||
Returns:
|
||||
Any: Object received from src.
|
||||
"""
|
||||
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv)
|
||||
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
|
||||
|
||||
|
||||
def _p2p_comm(
|
||||
@@ -557,7 +564,10 @@ class PipelineP2PCommunication:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
input_tensor = _recv_object(
|
||||
prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv
|
||||
prev_rank,
|
||||
cur_rank,
|
||||
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
|
||||
metadata_recv=metadata_recv,
|
||||
)
|
||||
|
||||
return input_tensor
|
||||
@@ -575,7 +585,10 @@ class PipelineP2PCommunication:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
output_tensor_grad = _recv_object(
|
||||
next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv
|
||||
next_rank,
|
||||
cur_rank,
|
||||
self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
|
||||
metadata_recv=metadata_recv,
|
||||
)
|
||||
|
||||
return output_tensor_grad
|
||||
@@ -595,7 +608,7 @@ class PipelineP2PCommunication:
|
||||
cur_rank,
|
||||
next_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
|
||||
send_metadata,
|
||||
send_metadata=send_metadata,
|
||||
)
|
||||
|
||||
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
|
||||
@@ -613,7 +626,7 @@ class PipelineP2PCommunication:
|
||||
cur_rank,
|
||||
prev_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
|
||||
send_metadata,
|
||||
send_metadata=send_metadata,
|
||||
)
|
||||
|
||||
def send_forward_recv_backward(
|
||||
@@ -622,6 +635,7 @@ class PipelineP2PCommunication:
|
||||
next_rank: Optional[int] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
|
||||
|
||||
@@ -642,6 +656,7 @@ class PipelineP2PCommunication:
|
||||
recv_group=group,
|
||||
send_metadata=send_metadata,
|
||||
metadata_recv=metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
|
||||
def send_backward_recv_forward(
|
||||
@@ -650,6 +665,7 @@ class PipelineP2PCommunication:
|
||||
prev_rank: Optional[int] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
|
||||
|
||||
@@ -670,6 +686,7 @@ class PipelineP2PCommunication:
|
||||
recv_group=group,
|
||||
send_metadata=send_metadata,
|
||||
metadata_recv=metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
|
||||
def p2p_communicate(
|
||||
|
Reference in New Issue
Block a user