Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests
* fix: remove send_forward_recv_forward, as a p2p op list needs to use the same group
* fix: make send and receive atomic
* feat: update P2PComm fn
* feat: add metadata cache in 1f1b
* feat: add metadata cache in interleaved pp
* feat: modify is_xx_stage fn
* revert: add _broadcast_object_list
* feat: add interleaved pp in llama policy
* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
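Note on the metadata cache: a P2P transfer of an arbitrary tensor tree normally starts with a metadata message describing shapes, dtypes, and tree structure so the receiver can allocate buffers. Because microbatch layouts repeat across iterations, that handshake only needs to happen for the first microbatch. A minimal sketch of the caching pattern this commit applies in both directions (the `send_metadata`/`metadata_recv` parameters and `create_fast_send_metadata` are taken from the diff below; the wrapper class itself is illustrative):

from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata

class CachedForwardComm:
    """Illustrative wrapper: cache P2P metadata after the first microbatch."""

    def __init__(self, comm: PipelineP2PCommunication):
        self.comm = comm
        self.send_metadata = True  # send the shape/dtype header only once
        self.metadata_recv = None  # cache the incoming layout after the first recv

    def send(self, obj, next_rank=None):
        self.comm.send_forward(obj, next_rank, send_metadata=self.send_metadata)
        self.send_metadata = False  # later microbatches reuse the same layout

    def recv(self, prev_rank=None):
        obj = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv)
        if self.metadata_recv is None:
            # record the tensor-tree layout once, then skip the handshake
            self.metadata_recv = create_fast_send_metadata(obj)
        return obj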
@@ -7,7 +7,7 @@ from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map

 from colossalai.interface import OptimizerWrapper
-from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.device import get_current_device

@@ -27,6 +27,7 @@ class InterleavedSchedule(PipelineSchedule):
         assert (
             num_microbatch is not None or microbatch_size is not None
         ), "Either num_microbatch or microbatch_size should be provided"
+
         self.comm = PipelineP2PCommunication(stage_manager)
         self.num_microbatch = num_microbatch
         self.microbatch_size = microbatch_size
@@ -34,8 +35,15 @@ class InterleavedSchedule(PipelineSchedule):

         self.batch: Any
         self.batch_size: int
+        self.last_batch_size: Optional[int] = None
         self.microbatch_offset: List[int]

+        # P2PMeta cache
+        self.send_metadata_forward = True
+        self.send_metadata_backward = True
+        self.metadata_recv_forward = None
+        self.metadata_recv_backward = None
+
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.

@@ -48,6 +56,11 @@ class InterleavedSchedule(PipelineSchedule):
         batch = tree_map(partial(to_device, device=device), batch)
         self.batch = batch
         self.batch_size = get_batch_size(batch)
+        if self.last_batch_size is None:
+            self.last_batch_size = self.batch_size
+        else:
+            assert self.forward_only or self.last_batch_size == self.batch_size
+            # TODO: support arbitrary batch size when forward_only=True
         self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
         if self.num_microbatch is not None:
             assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
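A quick worked example of the checks above, with hypothetical sizes: a batch of 32 split into 8 microbatches gives 4 samples per microbatch, while a later batch of a different size is only tolerated in forward-only mode.

batch_size, num_microbatch = 32, 8        # hypothetical values
assert batch_size % num_microbatch == 0   # 32 / 8 = 4 samples per microbatch
last_batch_size = batch_size
# a subsequent batch of size 30 would now trip the assert unless forward_only is True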
@@ -106,12 +119,13 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: The input tensor or input tensor list.
         """
-        if self.stage_manager.is_first_stage(model_chunk_id):
-            input_tensor = None
-        else:
-            input_tensor = self.comm.recv_forward(prev_rank)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_first_stage():
+                input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
+                if self.metadata_recv_forward is None:
+                    self.metadata_recv_forward = create_fast_send_metadata(input_tensor)

-        return input_tensor
+                return input_tensor

     def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
         """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
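On the first microbatch, `metadata_recv_forward` is still `None`, so `recv_forward` falls back to the full handshake, and the received object's layout is recorded once via `create_fast_send_metadata`. Every subsequent microbatch passes the cached layout in, letting the receiver pre-allocate buffers without waiting for another metadata message. `recv_backward` below mirrors this with `metadata_recv_backward`.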
@@ -124,14 +138,15 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: The input gradient tensor or gradient tensor list.
         """
-        if self.stage_manager.is_last_stage(model_chunk_id):
-            output_tensor_grad = None
-        else:
-            output_tensor_grad = self.comm.recv_backward(next_rank)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_last_stage():
+                output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
+                if self.metadata_recv_backward is None:
+                    self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)

-        return output_tensor_grad
+                return output_tensor_grad

-    def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
+    def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None:
         """Sends the input tensor to the next stage in pipeline.
            For interleaved 1F1B.

@@ -140,10 +155,12 @@ class InterleavedSchedule(PipelineSchedule):
             output_object (Any): Object to be sent.
             next_rank (int, optional): The rank of the recipient of the tensor.
         """
-        if not self.stage_manager.is_last_stage(model_chunk_id):
-            self.comm.send_forward(output_object, next_rank)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_last_stage():
+                self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
+                self.send_metadata_forward = False

-    def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
+    def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
            For interleaved 1F1B.

@@ -152,8 +169,44 @@ class InterleavedSchedule(PipelineSchedule):
             input_object (Any): Object to be sent.
             prev_rank (int, optional): The rank of the recipient of the tensor
         """
-        if not self.stage_manager.is_first_stage(model_chunk_id):
-            self.comm.send_backward(input_object, prev_rank)
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_first_stage():
+                self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
+                self.send_metadata_backward = False
+
+    def send_forward_recv_backward(
+        self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
+    ) -> Any:
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_last_stage():
+                output_tensor_grad = self.comm.send_forward_recv_backward(
+                    output_object,
+                    next_rank,
+                    send_metadata=self.send_metadata_forward,
+                    metadata_recv=self.metadata_recv_backward,
+                )
+                self.send_metadata_forward = False
+                if self.metadata_recv_backward is None:
+                    self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)
+
+                return output_tensor_grad
+
+    def send_backward_recv_forward(
+        self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None
+    ) -> Any:
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if not self.stage_manager.is_first_stage():
+                input_tensor = self.comm.send_backward_recv_forward(
+                    output_object,
+                    prev_rank,
+                    send_metadata=self.send_metadata_backward,
+                    metadata_recv=self.metadata_recv_forward,
+                )
+                self.send_metadata_backward = False
+                if self.metadata_recv_forward is None:
+                    self.metadata_recv_forward = create_fast_send_metadata(input_tensor)
+
+                return input_tensor

     def forward_step(
         self,
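The fused `send_forward_recv_backward` / `send_backward_recv_forward` helpers exist because of the fix named in the commit message: posting a send and its matching receive as one atomic batch of P2P ops, all on the same process group, is what keeps two neighboring ranks from deadlocking when each blocks on its send first. A minimal standalone sketch of that pattern with the public `torch.distributed` batching API (the tensors, peer rank, and process-group setup are illustrative, not taken from `PipelineP2PCommunication`):

import torch
import torch.distributed as dist

def send_recv_atomic(send_tensor: torch.Tensor, recv_buffer: torch.Tensor, peer: int):
    # Both ops go into one batch on the same group so NCCL can pair them;
    # issuing isend() and irecv() as separate blocking steps can deadlock
    # when two adjacent ranks each wait on their send.
    ops = [
        dist.P2POp(dist.isend, send_tensor, peer),
        dist.P2POp(dist.irecv, recv_buffer, peer),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()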
@@ -180,25 +233,24 @@ class InterleavedSchedule(PipelineSchedule):
         # for the first stage, input_obj is None
         # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict

-        self.stage_manager.model_chunk_id = model_chunk_id
-        if isinstance(model_chunk, ModuleList):
-            output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
-        else:
-            # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
-            internal_inputs = {} if input_obj is None else input_obj
-            internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
-            output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
-        self.stage_manager.model_chunk_id = None
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if isinstance(model_chunk, ModuleList):
+                output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
+            else:
+                # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
+                internal_inputs = {} if input_obj is None else input_obj
+                internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+                output_obj = model_forward(model_chunk, micro_batch, internal_inputs)

-        if self.stage_manager.is_last_stage(model_chunk_id):
-            loss = criterion(output_obj, micro_batch) / self.num_microbatch
-            if accum_loss is not None:
-                accum_loss.add_(loss.detach())
-            if outputs is not None:
-                outputs.append(tree_map(detach, output_obj))
-            return loss
-        else:
-            return output_obj
+            if self.stage_manager.is_last_stage():
+                loss = criterion(output_obj, micro_batch) / self.num_microbatch
+                if accum_loss is not None:
+                    accum_loss.add_(loss.detach())
+                if outputs is not None:
+                    outputs.append(tree_map(detach, output_obj))
+                return loss
+            else:
+                return output_obj

     def backward_step(
         self,
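`forward_step` now wraps the chunk dispatch in `stage_manager.switch_model_chunk_id(...)` instead of assigning and resetting `model_chunk_id` by hand, which also restores the previous value if the wrapped forward raises. The real implementation lives in `PipelineStageManager` and is not part of this diff; a plausible sketch:

from contextlib import contextmanager

@contextmanager
def switch_model_chunk_id(self, model_chunk_id: int):
    # Temporarily mark the active model chunk; the finally block guarantees
    # the old value comes back even if the wrapped forward/backward raises.
    old_model_chunk_id = self.model_chunk_id
    self.model_chunk_id = model_chunk_id
    try:
        yield
    finally:
        self.model_chunk_id = old_model_chunk_id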
@@ -267,15 +319,14 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             dict: A dict with keys: 'loss' and 'outputs'.
         """
-        # TODO: handle arbitrary batch size when forward_only == True
-        forward_only = not torch.is_grad_enabled()
+        self.forward_only = not torch.is_grad_enabled()
         if optimizer is None:
-            assert forward_only, "Optimizer should be passed when doing backward."
+            assert self.forward_only, "Optimizer should be passed when doing backward."

         self.load_batch(data_iter)

         num_microbatch = self.num_microbatch * self.num_model_chunks
-        if forward_only:
+        if self.forward_only:
             num_warmup_microbatch = num_microbatch
         else:
             num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
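The warmup count follows the standard interleaved-1F1B rule: every stage except the last runs extra forward passes before its first backward can arrive, and later stages need fewer of them. With hypothetical sizes:

num_stages = 4                      # hypothetical pipeline depth
for stage in range(num_stages):
    # stage 0 -> 6, stage 1 -> 4, stage 2 -> 2, stage 3 (last) -> 0
    print(stage, (num_stages - stage - 1) * 2)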
@@ -288,43 +339,29 @@ class InterleavedSchedule(PipelineSchedule):
         input_objs = None
         output_objs = None

-        if not forward_only:
+        if not self.forward_only:
             input_objs = [[] for _ in range(self.num_model_chunks)]
             output_objs = [[] for _ in range(self.num_model_chunks)]

-        outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None

-        if return_loss and self.stage_manager.is_last_stage(-1):
+        accum_loss = None
+        if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
             accum_loss = torch.zeros(1, device=get_current_device())
-        else:
-            accum_loss = None
-
-        # for ranks except the first one, get into recv state
-        input_obj = self.recv_forward(0)

         # Run warmup forward passes.
         for i in range(num_warmup_microbatch):
             model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
-            # recv first on first rank to avoid sending or receiving at the same time
-            if self.stage_manager.is_first_stage(-1):
-                input_obj = self.recv_forward(model_chunk_id)
-                output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-                self.send_forward(model_chunk_id, output_obj)
-                if not forward_only:
-                    input_objs[model_chunk_id].append(input_obj)
-                    output_objs[model_chunk_id].append(output_obj)
-            else:
-                output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-                if not forward_only:
-                    input_objs[model_chunk_id].append(input_obj)
-                    output_objs[model_chunk_id].append(output_obj)
-                self.send_forward(model_chunk_id, output_obj)
+            input_obj = self.recv_forward(model_chunk_id)
+            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+            if not self.forward_only:
+                input_objs[model_chunk_id].append(input_obj)
+                output_objs[model_chunk_id].append(output_obj)
+            self.send_forward(model_chunk_id, output_obj)

-            if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch:
-                break
-
-            model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True)
-            input_obj = self.recv_forward(model_chunk_id)
+        if num_microbatch_remaining > 0:
+            model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True)
+            input_obj = self.recv_forward(model_chunk_id)

         # Run 1F1B in steady state.
         for i in range(num_microbatch_remaining):
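`get_model_chunk_id` itself is unchanged by this diff. For context, the usual interleaved mapping (an assumption here, following the Megatron-style convention rather than code shown above) assigns microbatches to chunks in groups of `num_stages`:

def get_model_chunk_id(microbatch_id: int, num_stages: int, num_model_chunks: int) -> int:
    # Hypothetical restatement of the usual rule: with 4 stages and 2 chunks,
    # microbatches 0-3 run on chunk 0, 4-7 on chunk 1, 8-11 on chunk 0 again, ...
    return microbatch_id % (num_stages * num_model_chunks) // num_stages

assert [get_model_chunk_id(i, 4, 2) for i in range(8)] == [0] * 4 + [1] * 4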
@@ -332,11 +369,11 @@ class InterleavedSchedule(PipelineSchedule):
             last_iteration = i == num_microbatch_remaining - 1

             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
-            if forward_only:
-                self.send_forward(model_chunk_id, output_obj)
-
+            if self.forward_only:
                 if not last_iteration:
-                    input_obj = self.recv_forward(model_chunk_id)
+                    input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj)
+                else:
+                    self.send_forward(model_chunk_id, output_obj)

             else:
                 self.send_forward(model_chunk_id, output_obj)
@@ -354,18 +391,14 @@ class InterleavedSchedule(PipelineSchedule):

             # backward
             input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
-            self.send_backward(model_chunk_id, input_obj_grad)

-            if last_iteration:
-                input_obj = None
-            else:
+            if not last_iteration:
                 model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
                 input_obj = self.recv_forward(model_chunk_id)

+            model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
+            self.send_backward(model_chunk_id, input_obj_grad)

         # Run cooldown backward passes.
-        if not forward_only:
+        if not self.forward_only:
@@ -374,7 +407,7 @@ class InterleavedSchedule(PipelineSchedule):
             input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
             self.send_backward(model_chunk_id, input_obj_grad)

-        if not forward_only:
+        if not self.forward_only:
             assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)

         if outputs is not None:
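The commit message also mentions setting NCCL_BUFFSIZE in HybridParallelPlugin; that part of the diff is not shown on this page. NCCL_BUFFSIZE is a standard NCCL environment variable (bytes per communication buffer, 4 MiB by default), and a plugin can only influence it by exporting it before the first communicator is created. A sketch with a hypothetical value:

import os

# Must happen before torch.distributed creates any NCCL communicator.
# 128 MiB is a hypothetical choice; the value this commit uses is not shown here.
os.environ.setdefault("NCCL_BUFFSIZE", str(128 * 1024 * 1024))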