[pipeline]: add p2p fallback order and fix interleaved pp deadlock (#5214)

* fix: add fallback order option and update 1f1b

* fix: fix deadlock comm in interleaved pp

* test: modify p2p test
This commit is contained in:
Wenhao Chen
2024-01-03 11:34:49 +08:00
committed by GitHub
parent 3c0d82b19b
commit d799a3088f
5 changed files with 269 additions and 136 deletions

View File

@@ -54,10 +54,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -90,11 +90,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# NOTE: disable metadata cache when batch size changes (not valid anymore)
if self.batch_size != self.last_batch_size:
self.enable_metadata_cache = False
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
self.last_batch_size = self.batch_size
def load_micro_batch(self) -> Any:
@@ -119,9 +119,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Any: The input tensor or input tensor list.
"""
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
if self.enable_metadata_cache and self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
return input_tensor
@@ -136,13 +136,13 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Any: The input gradient tensor or gradient tensor list.
"""
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
if self.enable_metadata_cache and self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
def send_forward(self, output_tensor: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For 1F1B.
@@ -151,10 +151,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
self.send_metadata_forward = not self.enable_metadata_cache
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
@@ -163,10 +163,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
self.send_metadata_backward = not self.enable_metadata_cache
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
def send_forward_recv_backward(
self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
@@ -175,19 +177,24 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
send_prior_fallback = None # must not fallback
output_tensor_grad = self.comm.send_forward_recv_backward(
output_object,
output_tensor,
next_rank,
send_metadata=self.send_metadata_forward,
metadata_recv=self.metadata_recv_backward,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_metadata_forward = not self.enable_metadata_cache
if self.enable_metadata_cache and self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
def send_backward_recv_forward(
self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
@@ -196,15 +203,18 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_first_stage():
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
send_prior_fallback = None # must not fallback
input_tensor = self.comm.send_backward_recv_forward(
output_object,
input_tensor_grad,
prev_rank,
send_metadata=self.send_metadata_backward,
metadata_recv=self.metadata_recv_forward,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_metadata_backward = not self.enable_metadata_cache
if self.enable_metadata_cache and self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_fast_send_metadata(input_tensor)
return input_tensor
@@ -365,7 +375,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
output_obj_grad = self.send_forward_recv_backward(output_obj)
output_obj_grad = self.send_forward_recv_backward(
output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
@@ -379,7 +391,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if last_iteration:
self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(input_obj_grad)
input_obj = self.send_backward_recv_forward(
input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
# Run cooldown backward passes.
for i in range(num_warmup_microbatches):