[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests
* fix: remove send_forward_recv_forward as p2p op list need to use the same group
* fix: make send and receive atomic
* feat: update P2PComm fn
* feat: add metadata cache in 1f1b
* feat: add metadata cache in interleaved pp
* feat: modify is_xx_stage fn
* revert: add _broadcast_object_list
* feat: add interleaved pp in llama policy
* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
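The metadata cache these bullets describe comes down to two rules: attach shape/dtype metadata only to the first send in each direction, and let the receiver cache what it saw so later microbatches skip the extra handshake. The toy sketch below illustrates just that idea; the names (Channel, TensorMeta) are hypothetical and this is not ColossalAI's PipelineP2PCommunication API.

# Illustrative sketch only: a toy in-memory "channel" that mimics the P2P
# metadata cache added in this commit. Not ColossalAI code.
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch


@dataclass
class TensorMeta:
    shape: Tuple[int, ...]
    dtype: torch.dtype


class Channel:
    """Toy stand-in for a P2P link between two adjacent pipeline stages."""

    def __init__(self) -> None:
        self.queue: list = []

    def send(self, tensors: List[torch.Tensor], send_metadata: bool) -> None:
        # The metadata header is attached only on the first send; afterwards the
        # receiver already knows the shapes/dtypes and the handshake is skipped.
        header = [TensorMeta(tuple(t.shape), t.dtype) for t in tensors] if send_metadata else None
        self.queue.append((header, tensors))

    def recv(self, metadata_recv: Optional[List[TensorMeta]]):
        header, tensors = self.queue.pop(0)
        # Reuse the cached metadata when the sender omitted the header.
        meta = header if header is not None else metadata_recv
        assert meta is not None, "the first exchange must carry metadata"
        return meta, tensors


# First microbatch: metadata travels with the tensors and gets cached;
# subsequent microbatches reuse the cache on both sides.
chan = Channel()
cached_meta = None
for step in range(3):
    chan.send([torch.randn(2, 4)], send_metadata=(cached_meta is None))
    meta, received = chan.recv(metadata_recv=cached_meta)
    cached_meta = meta
    print(step, [m.shape for m in meta], received[0].shape)

In the hunks below, send_metadata_forward/send_metadata_backward play the role of the send_metadata flag, and metadata_recv_forward/metadata_recv_backward hold the cached header produced by create_fast_send_metadata.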
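The NCCL_BUFFSIZE bullet is a plugin-level change that does not appear in the hunks below; it is an environment variable NCCL reads when its communicators are created. A hedged sketch of the general pattern (the value is illustrative, not the one HybridParallelPlugin sets):

import os

# Illustration only: NCCL_BUFFSIZE (in bytes) must be exported before
# torch.distributed initializes the NCCL backend, or it has no effect on
# the communicators that are already created.
os.environ.setdefault("NCCL_BUFFSIZE", str(128 * 1024 * 1024))  # example value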
@@ -7,7 +7,7 @@ from torch.nn import Module
 from torch.utils._pytree import tree_map
 
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device
 
@@ -42,14 +42,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         assert (
             num_microbatches is not None or microbatch_size is not None
         ), "Either num_microbatches or microbatch_size should be provided"
 
         self.comm = PipelineP2PCommunication(stage_manager)
         self.num_microbatches = num_microbatches
         self.microbatch_size = microbatch_size
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
+        self.last_batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
         self._use_microbatch_size = num_microbatches is None
 
+        # P2PMeta cache
+        self.send_metadata_forward = True
+        self.send_metadata_backward = True
+        self.metadata_recv_forward = None
+        self.metadata_recv_backward = None
+
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
 
@@ -60,8 +68,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         batch = next(data_iter)
         if device is not None:
             batch = tree_map(partial(to_device, device=device), batch)
+
         self.batch = batch
         self.batch_size = get_batch_size(batch)
+        if self.last_batch_size is None:
+            self.last_batch_size = self.batch_size
+        else:
+            assert self.forward_only or self.last_batch_size == self.batch_size
+            # TODO: support arbitrary batch size when forward_only=True
         self.microbatch_offset = 0
         if not self._use_microbatch_size:
             assert (
@@ -92,12 +106,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         Returns:
             Any: The input tensor or input tensor list.
         """
-        if self.stage_manager.is_first_stage():
-            input_tensor = None
-        else:
-            input_tensor = self.comm.recv_forward(prev_rank)
+        if not self.stage_manager.is_first_stage():
+            input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
+            if self.metadata_recv_forward is None:
+                self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
 
-        return input_tensor
+            return input_tensor
 
     def recv_backward(self, next_rank: int = None) -> Any:
         """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
@@ -109,12 +123,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         Returns:
             Any: The input gradient tensor or gradient tensor list.
         """
-        if self.stage_manager.is_last_stage():
-            output_tensor_grad = None
-        else:
-            output_tensor_grad = self.comm.recv_backward(next_rank)
+        if not self.stage_manager.is_last_stage():
+            output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
+            if self.metadata_recv_backward is None:
+                self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
 
-        return output_tensor_grad
+            return output_tensor_grad
 
     def send_forward(self, output_object: Any, next_rank: int = None) -> None:
         """Sends the input tensor to the next stage in pipeline.
@@ -125,18 +139,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             next_rank (int, optional): The rank of the recipient of the tensor.
         """
         if not self.stage_manager.is_last_stage():
-            self.comm.send_forward(output_object, next_rank)
-
-    def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
-        """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
-        For 1F1B.
-
-        Args:
-            output_object (Any): Object to be sent.
-            next_rank (int, optional): The rank of the recipient of the tensor.
-        """
-        if not self.stage_manager.is_last_stage():
-            return self.comm.send_forward_recv_backward(output_object, next_rank)
+            self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
+            self.send_metadata_forward = False
 
     def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
@@ -147,7 +151,29 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             prev_rank (int, optional): The rank of the recipient of the tensor
         """
         if not self.stage_manager.is_first_stage():
-            self.comm.send_backward(input_object, prev_rank)
+            self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
+            self.send_metadata_backward = False
+
+    def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
+        """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
+        For 1F1B.
+
+        Args:
+            output_object (Any): Object to be sent.
+            next_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if not self.stage_manager.is_last_stage():
+            output_tensor_grad = self.comm.send_forward_recv_backward(
+                output_object,
+                next_rank,
+                send_metadata=self.send_metadata_forward,
+                metadata_recv=self.metadata_recv_backward,
+            )
+            self.send_metadata_forward = False
+            if self.metadata_recv_backward is None:
+                self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
+
+            return output_tensor_grad
 
     def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
         """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
@@ -158,23 +184,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             prev_rank (int, optional): The rank of the recipient of the tensor.
         """
         if not self.stage_manager.is_first_stage():
-            return self.comm.send_backward_recv_forward(output_object, prev_rank)
+            input_tensor = self.comm.send_backward_recv_forward(
+                output_object,
+                prev_rank,
+                send_metadata=self.send_metadata_backward,
+                metadata_recv=self.metadata_recv_forward,
+            )
+            self.send_metadata_backward = False
+            if self.metadata_recv_forward is None:
+                self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
 
-    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
-        """Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
-        For 1F1B.
-
-        Args:
-            input_object (Any): Object to be sent.
-            prev_rank (int, optional): The previous rank of the recipient of the tensor.
-            next_rank (int, optional): The next rank of the recipient of the tensor.
-        """
-        if self.stage_manager.is_first_stage():
-            return self.comm.send_forward(input_object, next_rank)
-        elif self.stage_manager.is_last_stage():
-            return self.comm.recv_forward(prev_rank)
-        else:
-            return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
+            return input_tensor
 
     def forward_step(
         self,
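The fused send_backward_recv_forward above (and send_forward_recv_backward before it) replaces the old separate send-then-receive calls; per the commit message, the paired operations must be issued together and against the same group. A minimal torch.distributed sketch of that pattern follows; it assumes a two-rank setup with pre-allocated buffers and is an illustration, not the code inside PipelineP2PCommunication.

# Minimal sketch, not ColossalAI code: issue a send and a receive as one
# batched P2P call so the pair behaves atomically. Assumes torch.distributed
# is already initialized with the NCCL backend and `peer` is the neighbouring
# stage's rank.
import torch
import torch.distributed as dist


def fused_send_recv(send_tensor: torch.Tensor, recv_tensor: torch.Tensor, peer: int) -> torch.Tensor:
    ops = [
        dist.P2POp(dist.isend, send_tensor, peer),
        dist.P2POp(dist.irecv, recv_tensor, peer),
    ]
    # All ops in the list are handed to the backend together and, as the
    # commit message notes, must target the same process group.
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_tensor

A stage would pass its gradient as send_tensor and an activation-shaped buffer (known from the cached metadata) as recv_tensor, so neither neighbour blocks on a send the other side has not posted yet.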
@@ -276,9 +296,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         Returns:
             dict: A dict with keys: 'loss' and 'outputs'.
         """
-        forward_only = not torch.is_grad_enabled()
+
+        self.forward_only = not torch.is_grad_enabled()
         if optimizer is None:
-            assert forward_only, "Optimizer should be passed when doing backward."
+            assert self.forward_only, "Optimizer should be passed when doing backward."
 
         self.load_batch(data_iter)
 
@@ -291,25 +312,22 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         input_objs = None
         output_objs = None
 
-        if not forward_only:
+        if not self.forward_only:
             input_objs = []
             output_objs = []
 
-        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
+        accum_loss = None
         if return_loss and self.stage_manager.is_last_stage():
             accum_loss = torch.zeros(1, device=get_current_device())
-        else:
-            accum_loss = None
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
 
         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
             input_obj = self.recv_forward()
-
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
-
             self.send_forward(output_obj)
 
-            if not forward_only:
+            if not self.forward_only:
                 input_objs.append(input_obj)
                 output_objs.append(output_obj)
 
@@ -324,16 +342,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             last_iteration = i == (num_microbatches_remaining - 1)
 
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
-            if forward_only:
+
+            if self.forward_only:
                 self.send_forward(output_obj)
 
                 if not last_iteration:
                     input_obj = self.recv_forward()
-            else:
-                # TODO adjust here
-                self.send_forward(output_obj)
-                output_obj_grad = self.recv_backward()
 
+            else:
+                output_obj_grad = self.send_forward_recv_backward(output_obj)
                 # Add input_obj and output_obj to end of list.
                 input_objs.append(input_obj)
                 output_objs.append(output_obj)
@@ -345,13 +362,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-
                 if last_iteration:
                     input_obj = None
+                    self.send_backward(input_obj_grad)
                 else:
-                    input_obj = self.recv_forward()
-                self.send_backward(input_obj_grad)
+                    input_obj = self.send_backward_recv_forward(input_obj_grad)
 
         # Run cooldown backward passes.
-        if not forward_only:
+        if not self.forward_only:
             for i in range(num_warmup_microbatches):
                 input_obj = input_objs.pop(0)
                 output_obj = output_objs.pop(0)
@@ -360,6 +376,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
                 self.send_backward(input_obj_grad)
 
+        if not self.forward_only:
+            assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
+
         if outputs is not None:
             if isinstance(model, ModelWrapper):
                 model = model.unwrap()