[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)

* test: add more p2p tests

* fix: remove send_forward_recv_forward as p2p op list need to use the same group

* fix: make send and receive atomic

* feat: update P2PComm fn

* feat: add metadata cache in 1f1b

* feat: add metadata cache in interleaved pp

* feat: modify is_xx_stage fn

* revert: add _broadcast_object_list

* feat: add interleaved pp in llama policy

* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
This commit is contained in:
Wenhao Chen
2023-12-22 10:44:00 +08:00
committed by GitHub
parent af952673f7
commit 4fa689fca1
15 changed files with 728 additions and 446 deletions

View File

@@ -7,7 +7,7 @@ from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device
@@ -42,14 +42,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert (
num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.last_batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self._use_microbatch_size = num_microbatches is None
# P2PMeta cache
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -60,8 +68,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch
self.batch_size = get_batch_size(batch)
if self.last_batch_size is None:
self.last_batch_size = self.batch_size
else:
assert self.forward_only or self.last_batch_size == self.batch_size
# TODO: support arbitrary batch size when forward_only=True
self.microbatch_offset = 0
if not self._use_microbatch_size:
assert (
@@ -92,12 +106,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage():
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
if self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
return input_tensor
return input_tensor
def recv_backward(self, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
@@ -109,12 +123,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
if self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
return output_tensor_grad
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
@@ -125,18 +139,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank)
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
return self.comm.send_forward_recv_backward(output_object, next_rank)
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
self.send_metadata_forward = False
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
@@ -147,7 +151,29 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank)
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
self.send_metadata_backward = False
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.send_forward_recv_backward(
output_object,
next_rank,
send_metadata=self.send_metadata_forward,
metadata_recv=self.metadata_recv_backward,
)
self.send_metadata_forward = False
if self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
return output_tensor_grad
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
@@ -158,23 +184,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_first_stage():
return self.comm.send_backward_recv_forward(output_object, prev_rank)
input_tensor = self.comm.send_backward_recv_forward(
output_object,
prev_rank,
send_metadata=self.send_metadata_backward,
metadata_recv=self.metadata_recv_forward,
)
self.send_metadata_backward = False
if self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The previous rank of the recipient of the tensor.
next_rank (int, optional): The next rank of the recipient of the tensor.
"""
if self.stage_manager.is_first_stage():
return self.comm.send_forward(input_object, next_rank)
elif self.stage_manager.is_last_stage():
return self.comm.recv_forward(prev_rank)
else:
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
return input_tensor
def forward_step(
self,
@@ -276,9 +296,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
assert self.forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter)
@@ -291,25 +312,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_objs = None
output_objs = None
if not forward_only:
if not self.forward_only:
input_objs = []
output_objs = []
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
accum_loss = None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.send_forward(output_obj)
if not forward_only:
if not self.forward_only:
input_objs.append(input_obj)
output_objs.append(output_obj)
@@ -324,16 +342,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if forward_only:
if self.forward_only:
self.send_forward(output_obj)
if not last_iteration:
input_obj = self.recv_forward()
else:
# TODO adjust here
self.send_forward(output_obj)
output_obj_grad = self.recv_backward()
else:
output_obj_grad = self.send_forward_recv_backward(output_obj)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
@@ -345,13 +362,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_obj = None
self.send_backward(input_obj_grad)
else:
input_obj = self.recv_forward()
self.send_backward(input_obj_grad)
input_obj = self.send_backward_recv_forward(input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
if not self.forward_only:
for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
@@ -360,6 +376,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(input_obj_grad)
if not self.forward_only:
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()