[Feature] optimize PP overlap (#5735)

* update to fully overlap, still debugging

* improve interface

* fixed deadlock bug

* debug NaN loss

* (experimental) use one comm group for send_fw_recv_fw to fix NaN

* cleaned up interfaces; use one batch p2p for all

* clean up; removed the double p2p batch case

* p2p test passed

* improve overlap: send fwd before backward

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tentatively use 2 p2p batches

* remove two p2p batches

* fix typos

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove pp.sh

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@notebook-c55824c0-7742-45e8-9591-c855bb77ad29-0.notebook-c55824c0-7742-45e8-9591-c855bb77ad29.colossal-ai.svc.cluster.local>
Author: Edenzzzz
Date: 2024-06-26 14:48:02 +08:00
Committed by: GitHub
parent 4ccaaaab63
commit 2a25a2aff7
9 changed files with 457 additions and 358 deletions
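
The diff below replaces the blocking combined calls (send_forward_recv_backward / send_backward_recv_forward) with batched asynchronous P2P ops that return wait handles, which the schedule waits on only right before the dependent compute step. A minimal sketch of that pattern follows; it is not the actual PipelineP2PCommunication implementation, and prev_rank, next_rank, and the preallocated recv_buffer are hypothetical placeholders (the real code also exchanges tensor metadata and honors send_first ordering):

import torch
import torch.distributed as dist

def send_forward_recv_forward(output_tensor, recv_buffer, prev_rank, next_rank):
    # Launch both directions in one batched p2p call; return the async work
    # handles instead of blocking, so communication overlaps with the
    # forward/backward compute issued in between.
    ops = [
        dist.P2POp(dist.isend, output_tensor, next_rank),
        dist.P2POp(dist.irecv, recv_buffer, prev_rank),
    ]
    handles = dist.batch_isend_irecv(ops)
    return recv_buffer, handles

def _wait_p2p(wait_handles):
    # Block only when the received tensor is actually consumed
    # (mirrors the helper added in this commit).
    if wait_handles is not None:
        for req in wait_handles:
            req.wait()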


@@ -1,8 +1,9 @@
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
@@ -16,6 +17,12 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
from .base import PipelineSchedule
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()
class InterleavedSchedule(PipelineSchedule):
def __init__(
self,
@@ -24,13 +31,15 @@ class InterleavedSchedule(PipelineSchedule):
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
overlap_p2p: bool = True,
) -> None:
super().__init__(stage_manager)
assert (
num_microbatch is not None or microbatch_size is not None
), "Either num_microbatch or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
self.overlap_p2p = overlap_p2p
self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size
self.num_model_chunks = num_model_chunks
@@ -113,14 +122,17 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
int: The model chunk idx of the input microbatch_id
"""
assert microbatch_id < self.num_microbatch * self.num_model_chunks
assert (
microbatch_id < self.num_microbatch * self.num_model_chunks
), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not is_forward:
# Reverse order
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For interleaved 1F1B.
@@ -130,16 +142,19 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
input_tensor, wait_handles = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
return input_tensor, wait_handles
return None, []
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For interleaved 1F1B.
@@ -149,16 +164,20 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank, metadata_recv=self.grad_metadata_recv
)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad, wait_handles
return output_tensor_grad
return None, []
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline.
For interleaved 1F1B.
@@ -166,13 +185,18 @@ class InterleavedSchedule(PipelineSchedule):
model_chunk_id (int): The current model chunk idx.
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles
return []
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
For interleaved 1F1B.
@@ -180,99 +204,61 @@ class InterleavedSchedule(PipelineSchedule):
model_chunk_id (int): The current model chunk idx.
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
)
self.send_grad_metadata = not self.enable_metadata_cache
def send_forward_recv_backward(
self,
model_chunk_id_send: int,
model_chunk_id_recv: int,
output_tensor: Any,
next_rank: Optional[int] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
send_data = not self.stage_manager.is_last_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
recv_data = not self.stage_manager.is_last_stage()
if send_data and recv_data:
if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
send_prior_fallback = None # must not fallback
output_tensor_grad = self.comm.send_forward_recv_backward(
output_tensor,
next_rank,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad
# send only or recv only
self.send_forward(model_chunk_id_send, output_tensor)
return self.recv_backward(model_chunk_id_recv)
def send_backward_recv_forward(
self,
model_chunk_id_send: int,
model_chunk_id_recv: int,
input_tensor_grad: Any,
prev_rank: Optional[int] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
send_data = not self.stage_manager.is_first_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
recv_data = not self.stage_manager.is_first_stage()
if send_data and recv_data:
if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
send_prior_fallback = None # must not fallback
input_tensor = self.comm.send_backward_recv_forward(
input_tensor_grad,
prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
# send only or recv only
self.send_backward(model_chunk_id_send, input_tensor_grad)
return self.recv_forward(model_chunk_id_recv)
return send_handles
return []
def send_forward_recv_forward(
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
):
if send_prior:
self.send_forward(model_chunk_id_send, output_tensor)
input_tensor = self.recv_forward(model_chunk_id_recv)
else:
input_tensor = self.recv_forward(model_chunk_id_recv)
self.send_forward(model_chunk_id_send, output_tensor)
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_first: bool = True
) -> Tuple[Any, List]:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
is_send = not self.stage_manager.is_last_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
is_recv = not self.stage_manager.is_first_stage()
input_tensor, wait_handles = self.comm.send_forward_recv_forward(
output_tensor,
is_send,
is_recv,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.tensor_metadata_recv,
send_first=send_first,
)
# Cache metadata
self.send_tensor_metadata = not self.enable_metadata_cache and is_send
if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
return input_tensor, wait_handles
def send_backward_recv_backward(
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
):
if send_prior:
self.send_backward(model_chunk_id_send, input_tensor_grad)
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
else:
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
self.send_backward(model_chunk_id_send, input_tensor_grad)
return output_tensor_grad
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_first: bool = True
) -> Tuple[Any, List]:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
is_send = not self.stage_manager.is_first_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
is_recv = not self.stage_manager.is_last_stage()
output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
input_tensor_grad,
is_send,
is_recv,
send_metadata=self.send_grad_metadata,
metadata_recv=self.grad_metadata_recv,
send_first=send_first,
)
# Cache metadata
self.send_grad_metadata = not self.enable_metadata_cache and is_send
if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad, wait_handles
def forward_step(
self,
@@ -294,10 +280,12 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
# Load input ids, attention mask and labels
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it must be a dict
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
@@ -381,23 +369,27 @@ class InterleavedSchedule(PipelineSchedule):
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())
fwd_wait_handles = []
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)
for i in range(self.num_microbatch * self.num_model_chunks):
last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
last_batch = i == self.num_microbatch * self.num_model_chunks - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
# Wait until current input is received
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not last_iteration:
input_obj = self.send_forward_recv_forward(
if not last_batch:
input_obj, fwd_wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
send_prior=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0,
)
else:
self.send_forward(model_chunk_id, output_obj)
fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
if outputs is not None:
outputs = merge_batch(outputs)
@@ -420,7 +412,9 @@ class InterleavedSchedule(PipelineSchedule):
self.load_batch(data_iter)
num_microbatch = self.num_microbatch * self.num_model_chunks
# Forward + until 1st backward
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
# Steps needed to reach the last chunk
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
num_microbatch_remaining = num_microbatch - num_warmup_microbatch
@@ -435,35 +429,44 @@ class InterleavedSchedule(PipelineSchedule):
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())
bwd_wait_handles = []
# Get the 1st input batch
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)
# Run warmup forward passes.
for i in range(num_warmup_microbatch):
last_iteration = i == num_warmup_microbatch - 1
last_batch = i == num_warmup_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
# Wait for input
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
if last_iteration and num_microbatch_remaining == 0:
self.send_forward(model_chunk_id, output_obj)
if last_batch and num_microbatch_remaining == 0:
fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
else:
input_obj = self.send_forward_recv_forward(
input_obj, fwd_wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
send_prior=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0,
)
if num_microbatch_remaining > 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)
# Run 1F1B in steady state.
for i in range(num_microbatch_remaining):
last_iteration = i == num_microbatch_remaining - 1
fwd_batch_id = i + num_warmup_microbatch
last_batch = i == num_microbatch_remaining - 1
model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
# Wait for input.
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
@@ -473,64 +476,75 @@ class InterleavedSchedule(PipelineSchedule):
# Pop input_obj and output_obj from the start of the list for the backward pass.
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: perform 2x communication for forward and backward
def send_forward_recv_backward():
if last_iteration and num_microbatch == num_microbatch_remaining:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
self.send_forward(model_chunk_id, output_obj)
# Helper functions
def send_forward_recv_forward():
if last_batch:
model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
wait_handles = self.send_forward(model_chunk_id, output_obj)
return None, wait_handles
else:
output_obj_grad = self.send_forward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_obj, wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=self.get_model_chunk_id(fwd_batch_id, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(fwd_batch_id + 1, is_forward=True),
output_tensor=output_obj,
send_prior_fallback=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0
and i > 0, # Receive from warmup stage first in the first batch
)
return output_obj_grad
return input_obj, wait_handles
def send_backward_recv_forward():
if last_iteration:
def send_backward_recv_backward():
no_cooldown = num_microbatch == num_microbatch_remaining
if last_batch and no_cooldown:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
wait_handles = self.send_backward(model_chunk_id, input_obj_grad)
return None, wait_handles
else:
input_obj = self.send_backward_recv_forward(
output_obj_grad, wait_handles = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
send_first=self.stage_manager.stage % 2 == 0,
)
return input_obj
return output_obj_grad, wait_handles
if self.stage_manager.stage % 2 == 0:
output_obj_grad = send_forward_recv_backward()
input_obj = send_backward_recv_forward()
else:
input_obj = send_backward_recv_forward()
output_obj_grad = send_forward_recv_backward()
input_obj, fwd_wait_handles = send_forward_recv_forward()
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
# risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
# however in practice this works fine, and Megatron does this too
# (https://github.com/microsoft/Megatron-DeepSpeed/blob/bcedecd1ff788d4d363f3365fd396053a08d65be/megatron/core/pipeline_parallel/schedules.py#L774)
# if deadlock, call _wait_p2p(fwd_wait_handles) here
output_obj_grad, bwd_wait_handles = send_backward_recv_backward()
if num_microbatch_remaining == 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)
# Run cooldown backward passes.
for i in range(num_microbatch_remaining, num_microbatch):
last_iteration = i == num_microbatch - 1
last_batch = i == num_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
# output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
if not last_iteration:
output_obj_grad = self.send_backward_recv_backward(
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
# backward local grads
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
if not last_batch:
output_obj_grad, bwd_wait_handles = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
send_first=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
)
assert (not self.overlap_p2p) or len(bwd_wait_handles) > 0
else:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
_ = self.send_backward(model_chunk_id, input_obj_grad)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
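
For reference, a hedged usage sketch of the new overlap_p2p flag added by this commit. The stage_manager and num_model_chunks keyword names are assumptions inferred from the hunk context above (super().__init__(stage_manager), self.num_model_chunks = num_model_chunks), and the num_microbatch value is arbitrary:

# stage_manager is assumed to be an existing PipelineStageManager instance.
schedule = InterleavedSchedule(
    stage_manager=stage_manager,
    num_model_chunks=num_model_chunks,
    num_microbatch=8,
    overlap_p2p=True,  # new in this commit (also the default); launches p2p asynchronously and returns wait handles
)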