Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 01:28:31 +00:00)
[Pipeline Inference] Sync pipeline inference branch to main (#4820)
* [pipeline inference] pipeline inference (#4492)
  * add pp stage manager as circle stage
  * fix a bug when create process group
  * add ppinfer basic framework
  * add micro batch manager and support kvcache-pp gpt2 fwd
  * add generate schedule
  * use mb size to control mb number
  * support generate with kv cache
  * add output, remove unused code
  * add test
  * reuse shardformer to build model
  * refactor some code and use the same attribute name of hf
  * fix review and add test for generation
  * remove unused file
  * fix CI
  * add cache clear
  * fix code error
  * fix typo
* [Pipeline inference] Modify to tieweight (#4599)
  * add pp stage manager as circle stage
  * fix a bug when create process group
  * add ppinfer basic framework
  * add micro batch manager and support kvcache-pp gpt2 fwd
  * add generate schedule
  * use mb size to control mb number
  * support generate with kv cache
  * add output, remove unused code
  * add test
  * reuse shardformer to build model
  * refactor some code and use the same attribute name of hf
  * fix review and add test for generation
  * remove unused file
  * modify the way of saving newtokens
  * modify to tieweight
  * modify test
  * remove unused file
  * solve review
  * add docstring
* [Pipeline inference] support llama pipeline inference (#4647)
  * support llama pipeline inference
  * remove tie weight operation
* [pipeline inference] Fix the blocking of communication when ppsize is 2 (#4708)
  * add benchmark verbose
  * fix export tokens
  * fix benchmark verbose
  * add P2POp style to do p2p communication
  * modify schedule as p2p type when ppsize is 2
  * remove unused code and add docstring
* [Pipeline inference] Refactor code, add docsting, fix bug (#4790)
  * add benchmark script
  * update argparse
  * fix fp16 load
  * refactor code style
  * add docstring
  * polish code
  * fix test bug
* [Pipeline inference] Add pipeline inference docs (#4817)
  * add readme doc
  * add a ico
  * Add performance
  * update table of contents
* refactor code (#4873)
colossalai/pipeline/p2p.py
@@ -160,6 +160,86 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
    return object_list[0]


def _p2p_comm(
    tensor_send_next: torch.Tensor,
    recv_prev: bool,
    peer: int,
    group: ProcessGroup,
    comm_dtype: torch.dtype = torch.float16,
):
    """
    Send and receive tensors using P2P communication. Used when the pipeline size is 2, to avoid the blocking
    that broadcast-style communication would cause between the two stages.

    Args:
        tensor_send_next (torch.Tensor): tensor to be sent to the next stage
        recv_prev (bool): whether to receive a tensor from the previous stage
        peer (int): rank of the peer
        group (ProcessGroup): process group
        comm_dtype (torch.dtype): dtype of the tensor to be sent

    Returns:
        torch.Tensor: tensor received from the previous stage
    """
    # send and recv shape
    send_next_shape = None
    recv_prev_shape = None

    if tensor_send_next is not None:
        send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
    if recv_prev:
        recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)

    ops = []
    if send_next_shape is not None:
        send_next_op = dist.P2POp(dist.isend, send_next_shape, peer=peer, group=group)
        ops.append(send_next_op)
    if recv_prev_shape is not None:
        recv_prev_op = dist.P2POp(
            dist.irecv,
            recv_prev_shape,
            peer=peer,
            group=group,
        )
        ops.append(recv_prev_op)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

    if recv_prev_shape is not None:
        recv_prev_shape = recv_prev_shape.tolist()

    # send and recv data
    tensor_recv_prev = None
    if recv_prev:
        tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)

    ops = []
    if tensor_send_next is not None:
        send_next_op = dist.P2POp(
            dist.isend,
            tensor_send_next,
            peer=peer,
            group=group,
        )
        ops.append(send_next_op)

    if tensor_recv_prev is not None:
        recv_prev_op = dist.P2POp(
            dist.irecv,
            tensor_recv_prev,
            peer=peer,
            group=group,
        )
        ops.append(recv_prev_op)
    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    return tensor_recv_prev
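

# Minimal usage sketch (illustrative addition, not part of the committed file): in a
# 2-stage pipeline both ranks call _p2p_comm in the same step, so the batched
# isend/irecv operations pair up and neither stage blocks the other. `hidden`, `peer`
# and `group` are placeholder names assumed for this example.
def _exchange_hidden_states_sketch(hidden: torch.Tensor, peer: int, group: ProcessGroup) -> torch.Tensor:
    return _p2p_comm(
        tensor_send_next=hidden,    # hidden states produced by this stage
        recv_prev=True,             # also expect the peer stage's hidden states
        peer=peer,
        group=group,
        comm_dtype=hidden.dtype,
    )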


class PipelineP2PCommunication:
    def __init__(self, stage_manager: PipelineStageManager) -> None:
        self.stage_manager = stage_manager
@@ -221,3 +301,17 @@ class PipelineP2PCommunication:
        prev_rank = self.stage_manager.get_prev_rank()
        cur_rank = self.stage_manager.get_rank()
        _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))

    def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None:
        """
        Sends the output tensor to the next stage in the pipeline and optionally receives one, using `P2POp` in torch.

        Args:
            output_object (Any): Object to be sent.
            recv_pre (bool): Whether to also receive a tensor from the previous stage.
            peer (int, optional): The rank of the recipient of the tensor. Defaults to the next stage's rank.
            comm_dtype (torch.dtype): dtype used for the communication.
        """
        if peer is None:
            peer = self.stage_manager.get_next_rank()
        cur_rank = self.stage_manager.get_rank()
        recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype)
        return recv_tensor
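
    # Usage sketch (illustrative, not from the committed file): on a 2-stage pipeline each
    # stage exchanges hidden states with its peer in a single call, e.g.
    #
    #     comm = PipelineP2PCommunication(stage_manager)
    #     received = comm.p2p_communicate(hidden_states, recv_pre=True, comm_dtype=torch.float16)
    #
    # `stage_manager` and `hidden_states` are assumed to be defined by the caller.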
colossalai/pipeline/schedule/generate.py (new file, 343 lines)
@@ -0,0 +1,343 @@
import time
from functools import partial
from typing import Any, Iterable, List, Optional, Union

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map

from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device

from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule

class ActionIntervalBuffer():
    """
    The buffer that holds the intermediate hidden states and the new token between actions, for each stage to use.
    """

    def __init__(self):
        self.hidden_states = None
        self.new_token = None

    def clear(self):
        self.hidden_states = None
        self.new_token = None


class GenerateSchedule(PipelineSchedule):
    """
    GenerateSchedule is the schedule that handles pipeline-parallel inference.
    In this schedule, the tied-weight layers (the embedding and the lm_head) are placed on the same device to save
    memory, so the output of each decoding step is produced on rank 0.

    Args:
        stage_manager (`PipelineStageManager`): Pipeline stage manager.
        mb_manager (`MicroBatchManager`): Micro batch manager.
        verbose (bool): Whether to log verbose information about the pipeline.
    """

    def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None:
        super().__init__(stage_manager)
        self.comm = PipelineP2PCommunication(stage_manager)
        self.mb_manager = mb_manager
        self.microbatch_size = mb_manager.micro_batch_size
        self.batch: Optional[Any] = None
        self.batch_size: Optional[int] = None
        self.microbatch_offset: Optional[int] = None
        self.num_microbatches: Optional[int] = None
        self.action_interval_buffer = ActionIntervalBuffer()
        self.verbose = verbose
        self.timestamps = None
        self.comm_dtype = None

    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from the data iterator.

        Args:
            data_iter (Iterable): Data iterator.
            device (Optional[torch.device], optional): Target device. Defaults to None.
        """
        batch = next(data_iter)
        if device is not None:
            batch = tree_map(partial(to_device, device=device), batch)
        self.batch = batch
        self.batch_size = get_batch_size(batch)
        self.microbatch_offset = 0
        assert self.batch_size % self.microbatch_size == 0, \
            f"Batch size should be divisible by the micro batch size, but got batch size {self.batch_size} and micro batch size {self.microbatch_size}"
        self.num_microbatches = self.batch_size // self.microbatch_size
        self.round = self.num_microbatches // self.stage_manager.num_stages
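
    # Worked example (illustrative): with batch_size=16, microbatch_size=4 and num_stages=2,
    # load_batch() gives num_microbatches = 16 // 4 = 4 and round = 4 // 2 = 2,
    # i.e. each generate step runs two scheduling rounds over the batch.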

    def load_micro_batch(self) -> Any:
        """Load a micro batch from the current batch.

        Returns:
            Any: Micro batch.
        """
        micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
        self.microbatch_offset += self.microbatch_size
        return tree_map(partial(to_device, device=get_current_device()), micro_batch)

    def _prepare_inputs_for_interval_stage(self):
        '''
        Prepare inputs for an interval (intermediate) stage; for all such stages, the only input is the past_key_values.

        Returns:
            dict: inputs for the interval stage, `{'past_key_values': torch.Tensor}` or `None`
        '''
        model_inputs = {
            'past_key_values': self.mb_manager.cur_kv_cache
        } if self.mb_manager.cur_kv_cache is not None else None
        return model_inputs

    def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
        '''
        Prepare inputs for the new token: the inputs form a dict with `input_ids`, `attention_mask` and `past_key_values`.
        `input_ids` is the new token, `attention_mask` is the previous mask extended by `1` at the end,
        and `past_key_values` is the past_key_values saved in the micro batch manager.

        Returns:
            dict: inputs for the new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
        '''
        new_mask = self.mb_manager.cur_descrption.attn_mask
        past_key_values = self.mb_manager.cur_descrption.kv_cache

        return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values)

    def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
        last_hidden_state = hidden_state[:, -1]
        input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1)
        return input_ids
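
    # Note (illustrative): this is greedy decoding; for logits of shape [batch, seq_len, vocab_size],
    # the last position is selected and argmax over the vocabulary yields next-token ids of shape [batch, 1].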

    def _recv_pre_stage(self) -> Any:
        '''
        Receive the output from the previous stage.

        Returns:
            Any: The output from the previous stage
        '''
        if self.stage_manager.num_stages == 2:
            return self.comm.p2p_recv()
        return self.comm.recv_forward()

    def _load_stage_action(self, model: Module) -> None:
        """
        In this action: 1. load the micro batch 2. do the forward pass 3. step the micro batch manager to update its state
        """
        inputs_dict = self.load_micro_batch()
        if self.verbose and self.stage_manager.is_first_stage():
            torch.cuda.synchronize()
            self.timestamps[self.mb_manager.idx].append(time.time())
        output_dict = model_forward(model, inputs_dict, None)

        self.mb_manager.step(inputs_dict, output_dict, None)
        self.action_interval_buffer.hidden_states = output_dict['hidden_states']

    def _gen_token_action(self, model: Module):
        """
        In this action: 1. do the forward pass with the hidden_states to generate a new token 2. step the micro batch manager to update its state
        """
        hidden_states = self.action_interval_buffer.hidden_states
        assert hidden_states is not None, "When the first stage is in GENERATE phase, the hidden states should not be None"
        hidden_states = {'hidden_states': hidden_states}
        logits = model_forward(model, None, hidden_states)
        if self.verbose and self.stage_manager.is_first_stage():
            torch.cuda.synchronize()
            self.timestamps[self.mb_manager.idx].append(time.time())
        assert 'logits' in logits, f"When the first stage is in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
        new_token = self._get_token_id(logits['logits'])

        self.mb_manager.step(None, None, new_token)
        self.action_interval_buffer.new_token = new_token
        self.action_interval_buffer.hidden_states = None

    def _head_encoding_action(self, model: Module):
        """
        In this action: 1. prepare the new-token inputs for the first stage 2. do the forward pass to get the hidden states 3. step the micro batch manager to update its state
        """
        new_token = self.action_interval_buffer.new_token
        assert new_token is not None, "When the first stage is in GENERATE phase, the new token should not be None"
        inputs_dict = self._prepare_inputs_for_new_token(new_token)
        output_dict = model_forward(model, inputs_dict, None)

        self.mb_manager.step(inputs_dict, output_dict, None)
        self.action_interval_buffer.hidden_states = output_dict['hidden_states']

    def _body_encoding_action(self, model: Module):
        hidden_states = self.action_interval_buffer.hidden_states
        assert hidden_states is not None, "When not in the first stage, the hidden states should not be None"
        inputs_dict = self._prepare_inputs_for_interval_stage()
        hidden_states = {'hidden_states': hidden_states}
        output_dict = model_forward(model, inputs_dict, hidden_states)

        self.mb_manager.step(inputs_dict, output_dict, None)
        self.action_interval_buffer.hidden_states = output_dict['hidden_states']

    def _comm_action(self, recv_pre: bool) -> torch.Tensor:
        """
        In this action: 1. receive the hidden_states from the previous stage 2. send this stage's hidden_states to the next stage
        """
        hidden_states = self.action_interval_buffer.hidden_states
        ret = self.comm.p2p_communicate(hidden_states, recv_pre, comm_dtype=self.comm_dtype)

        self.action_interval_buffer.hidden_states = ret

    def _gen_action(self, model: Module):
        """
        In the p2p step method we use the asynchronous `P2POp` communication primitive, so communication has to happen
        at the beginning of each micro batch; expressing this as an action list keeps the control flow clear. This
        function generates the action sequence for the current state, and the caller runs the actions one by one.

        Args:
            model (Module): Model to be run.

        Returns:
            List[Callable]: A list of actions; each action is a callable and is invoked in order.
        """
        actions = []
        if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
            actions.append(partial(self._comm_action, False))
            actions.append(partial(self._load_stage_action, model))
        elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.GENERATE:
            actions.append(partial(self._comm_action, True))
            actions.append(partial(self._gen_token_action, model))
            actions.append(partial(self._head_encoding_action, model))
        elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.COOLDOWN:
            actions.append(partial(self._comm_action, True))
            actions.append(partial(self._gen_token_action, model))
        # other stage
        else:
            actions.append(partial(self._comm_action, True))
            actions.append(partial(self._body_encoding_action, model))

        return actions
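
    # Resulting action sequences per state (summary of the branches above):
    #   first stage, PREFILL:  [_comm_action(False), _load_stage_action]
    #   first stage, GENERATE: [_comm_action(True), _gen_token_action, _head_encoding_action]
    #   first stage, COOLDOWN: [_comm_action(True), _gen_token_action]
    #   other stages:          [_comm_action(True), _body_encoding_action]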

    def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
        if self.stage_manager.num_stages == 2:
            return self.generate_step_p2p(model, data_iter)
        else:
            return self.generate_step_broadcast(model, data_iter)

    @torch.no_grad()
    def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
        """
        Forward one step of the pipeline. When the pipeline size is 2 the schedule forms a circle, so broadcast
        communication would block; we therefore use the asynchronous `P2POp` communication method.

        Args:
            model (Module): Model to be run.
            data_iter (Iterable): Data iterator.

        Returns:
            tuple: `(output_sequence, timestamps)`, the generated token sequences (non-empty on the first stage) and
            the per-microbatch timestamps collected when `verbose` is enabled.
        """
        output_sequence = []
        self.load_batch(data_iter)
        model.eval()
        self.comm_dtype = model.dtype

        whole_timestamp = []

        # run by round
        for _ in range(self.round):
            self.timestamps = (
                [[] for _ in range(self.stage_manager.num_stages)]
                if self.verbose and self.stage_manager.is_first_stage()
                else None
            )
            self.action_interval_buffer.clear()
            while self.mb_manager.is_micro_batch_done() is False:
                actions = self._gen_action(model)
                for action in actions:
                    action()
                self.mb_manager.next()
            # All micro batches in the current round are DONE
            if self.stage_manager.is_first_stage():
                output_sequence.extend(self.mb_manager.export_new_tokens())
            else:
                self._comm_action(False)
            self.mb_manager.clear()
            if self.verbose and self.stage_manager.is_first_stage():
                whole_timestamp.extend(self.timestamps)

        return output_sequence, whole_timestamp

    @torch.no_grad()
    def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
        """
        Forward one step of the pipeline using broadcast-style communication.

        Args:
            model (Module): Model to be run.
            data_iter (Iterable): Data iterator.

        Returns:
            tuple: `(output_sequence, timestamps)`, the generated token sequences (non-empty on the first stage) and
            the per-microbatch timestamps collected when `verbose` is enabled.
        """
        output_sequence = []
        self.load_batch(data_iter)
        model.eval()

        whole_timestamp = []
        # run by round
        for _ in range(self.round):
            self.timestamps = (
                [[] for _ in range(self.stage_manager.num_stages)]
                if self.verbose and self.stage_manager.is_first_stage()
                else None
            )
            while self.mb_manager.is_micro_batch_done() is False:
                inputs_dict = None
                new_token = None
                output_dict = None

                # First stage and in PREFILL phase: just load the inputs
                if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
                    inputs_dict = self.load_micro_batch()
                    if self.verbose and self.stage_manager.is_first_stage():
                        torch.cuda.synchronize()
                        self.timestamps[self.mb_manager.idx].append(time.time())
                    output_dict = model_forward(model, inputs_dict, None)
                    self.mb_manager.step(inputs_dict, output_dict, None)
                # In GENERATE phase
                else:
                    # Get hidden_states from the previous stage
                    hidden_states = self.comm.recv_forward()
                    if self.stage_manager.is_first_stage():
                        # First, generate a new token
                        assert hidden_states is not None, "When the first stage is in GENERATE phase, the hidden states should not be None"
                        logits = model_forward(model, None, hidden_states)
                        if self.verbose and self.stage_manager.is_first_stage():
                            torch.cuda.synchronize()
                            self.timestamps[self.mb_manager.idx].append(time.time())
                        assert 'logits' in logits, f"When the first stage is in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
                        new_token = self._get_token_id(logits['logits'])
                        self.mb_manager.step(None, None, new_token)
                        # If the current micro batch is not DONE, run it through the blocks again
                        if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                            inputs_dict = self._prepare_inputs_for_new_token(new_token)
                            output_dict = model_forward(model, inputs_dict, None)
                            self.mb_manager.step(inputs_dict, output_dict, None)
                    else:
                        assert hidden_states is not None, "When not in the first stage, the hidden states should not be None"
                        inputs_dict = self._prepare_inputs_for_interval_stage()
                        output_dict = model_forward(model, inputs_dict, hidden_states)
                        self.mb_manager.step(inputs_dict, output_dict, None)

                # The current micro batch is not DONE: send the hidden_states to the next stage
                if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                    self.comm.send_forward({'hidden_states': output_dict['hidden_states']})

                self.mb_manager.next()

            # All micro batches in the current round are DONE
            if self.stage_manager.is_first_stage():
                output_sequence.extend(self.mb_manager.export_new_tokens())
            self.mb_manager.clear()
            if self.verbose and self.stage_manager.is_first_stage():
                whole_timestamp.extend(self.timestamps)

        return output_sequence, whole_timestamp
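

# A minimal driver sketch (illustrative; not part of the committed file). It assumes a sharded
# `model`, a configured `PipelineStageManager` and `MicroBatchManager`, and a dataloader that
# yields batches of `input_ids`/`attention_mask`; those objects are created elsewhere.
def run_pipeline_generation(model: Module, stage_manager: PipelineStageManager,
                            mb_manager: MicroBatchManager, dataloader: Iterable):
    schedule = GenerateSchedule(stage_manager, mb_manager, verbose=False)
    # generate_step() dispatches to the P2P schedule for 2 stages and to the broadcast schedule otherwise.
    output_tokens, timestamps = schedule.generate_step(model, iter(dataloader))
    return output_tokens, timestamps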
@@ -12,6 +12,7 @@ class PipelineStageManager:
    Args:
        pg_mesh (ProcessGroupMesh): Process group mesh.
        pipeline_axis (int): The axis along which the pipeline is constructed.
        is_virtual (bool): Whether to use circle p2p communication, which makes the first and the last stage communicate with each other.

    Attributes:
        num_stages (int): Number of stages in the pipeline.

@@ -24,6 +25,7 @@ class PipelineStageManager:
        self.prev_rank: Optional[Tuple[int, ...]] = None
        self.next_rank: Optional[Tuple[int, ...]] = None
        self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}

        # init prev and next coord
        coord = self.pg_mesh.coordinate()
        # the prev rank of rank0 is the last rank
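        # Illustration (assumption, not from the diff): with is_virtual=True and 4 stages the
        # ring wraps around, e.g. prev(stage 0) == stage 3 and next(stage 3) == stage 0, which
        # is what lets the generate schedule run as a circle.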