Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)
* support only tp
* enable tp
* add support for bloom (#5008)
* [refactor] refactor gptq and smoothquant llama (#5012)
* refactor gptq and smoothquant llama
* fix import error
* fix linear import torch-int
* fix smoothquant llama import error
* fix import accelerate error
* fix bug
* fix import smooth cuda
* fix smoothcuda
* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)
merge chatglm with pp and tp
* [Refactor] remove useless inference code (#5022)
* remove useless code
* fix quant model
* fix test import bug
* mv original inference legacy
* fix chatglm2
* [Refactor] refactor policy search and quant type controlling in inference (#5035)
* [Refactor] refactor policy search and quant type controlling in inference
* [inference] update readme (#5051)
* update readme
* update readme
* fix architecture
* fix table
* fix table
* [inference] update example (#5053)
* update example
* fix run.sh
* fix rebase bug
* fix some errors
* update readme
* add some features
* update interface
* update readme
* update benchmark
* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
@@ -7,7 +7,7 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
 
-from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status
+from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.cuda import get_current_device
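The functional change in this hunk is the new import path: `MicroBatchManager` and `Status` now come from `colossalai.inference.engine` instead of `colossalai.inference.pipeline`. A minimal, hypothetical compatibility sketch for downstream scripts that need to run on releases from both sides of the refactor (assuming the respective module path exists on each side; this shim is not part of the commit):

# Hypothetical compatibility shim (not part of this commit): try the new module
# layout first, then fall back to the legacy path used before the refactor.
try:
    # Post-#5057 layout: schedule utilities live under colossalai.inference.engine
    from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
except ImportError:
    # Pre-refactor layout on older ColossalAI releases
    from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status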
@@ -93,9 +93,7 @@ class GenerateSchedule(PipelineSchedule):
         Returns:
             dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
         """
-        model_inputs = {
-            'infer_state': self.mb_manager.cur_descrption.infer_state
-        }
+        model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
         return model_inputs
 
     def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -129,8 +127,8 @@ class GenerateSchedule(PipelineSchedule):
 
     def _init_infer_state_action(self) -> None:
         """
-        This action is only for no first stage, to load batch and init infer_state.
-        1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
+        This action is only for no first stage, to load batch and init infer_state.
+        1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
         """
         inputs_dict = self.load_micro_batch()
         self.mb_manager.add_descrption(inputs_dict)
@@ -145,19 +143,19 @@ class GenerateSchedule(PipelineSchedule):
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
-        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, inputs_dict, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _gen_token_action(self, model: Module):
         """
-        This action is only for first stage
+        This action is only for first stage
         1.do the forward with hidden_states to generate new tokens 2.step to update
         """
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
-        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
         logits = model_forward(model, None, interval_inputs)
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
@@ -178,18 +176,18 @@ class GenerateSchedule(PipelineSchedule):
         new_token = self.action_interval_buffer.new_token
         assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
         inputs_dict = self._prepare_inputs_for_new_token(new_token)
-        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, inputs_dict, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _body_encoding_action(self, model: Module):
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When not first stage, the hidden states should not be None"
-        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, None, interval_inputs)
 
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _comm_action(self, recv_pre: bool) -> torch.Tensor:
         """
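The actions above do not pass results to each other directly; they hand intermediate tensors over through `self.action_interval_buffer` (`hidden_states`, `new_token`), which is cleared between rounds. A minimal illustrative stand-in for such a buffer (not the ColossalAI class; the field names simply mirror the attributes used above):

# Minimal sketch of the role the action interval buffer plays: it carries
# hidden_states / new_token between consecutive actions of the same micro batch.
from dataclasses import dataclass
from typing import Optional
import torch


@dataclass
class ActionIntervalBuffer:
    hidden_states: Optional[torch.Tensor] = None
    new_token: Optional[torch.Tensor] = None

    def clear(self):
        # Drop stale tensors before the next round starts.
        self.hidden_states = None
        self.new_token = None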
@@ -233,12 +231,73 @@ class GenerateSchedule(PipelineSchedule):
 
         return actions
 
+    def _gen_one_stage_action(self, model: Module):
+        """
+        In this function, it will generate a sequence action for current state, and do the action one by one.
+
+        Args:
+            model (Module): Model to be run.
+
+        Returns:
+            List[Callable]: A list of action, each action is a callable function, and it will be called in order.
+        """
+        actions = []
+
+        if self.mb_manager.cur_state is Status.PREFILL:
+            actions.append(partial(self._load_stage_action, model))
+        elif self.mb_manager.cur_state is Status.GENERATE:
+            actions.append(partial(self._gen_token_action, model))
+            actions.append(partial(self._head_encoding_action, model))
+        elif self.mb_manager.cur_state is Status.COOLDOWN:
+            actions.append(partial(self._gen_token_action, model))
+
+        return actions
+
     def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
-        if self.stage_manager.num_stages == 2:
+        if self.stage_manager.num_stages == 1:
+            return self.generate_step_one_stage(model, data_iter)
+        elif self.stage_manager.num_stages == 2:
             return self.generate_step_p2p(model, data_iter)
         else:
             return self.generate_step_broadcast(model, data_iter)
 
+    @torch.no_grad()
+    def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+        """
+        Forward one step of the pipeline, when pipeline size is 1.
+
+        Args:
+            model (Module): Model to be run.
+            data_iter (Iterable): Data iterator.
+
+        Returns:
+            Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
+        """
+        output_sequence = []
+        self.load_batch(data_iter)
+        model.eval()
+        self.comm_dtype = model.dtype
+
+        whole_timestamp = []
+
+        # run by round
+        for _ in range(self.round):
+            self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
+            self.action_interval_buffer.clear()
+            while self.mb_manager.is_micro_batch_done() is False:
+                actions = self._gen_one_stage_action(model)
+                for action in actions:
+                    action()
+                self.mb_manager.next()
+            # All microbatch in current round is DONE
+            output_sequence.extend(self.mb_manager.export_new_tokens())
+
+            self.mb_manager.clear()
+            if self.verbose:
+                whole_timestamp.extend(self.timestamps)
+
+        return output_sequence, whole_timestamp
+
     @torch.no_grad()
     def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
         """
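The single-stage path added above follows a state-driven action list: `_gen_one_stage_action` inspects the micro-batch manager's state (PREFILL, GENERATE, COOLDOWN) and returns `partial`-bound callables, which `generate_step_one_stage` executes in order before advancing to the next micro batch. A self-contained toy sketch of that pattern (the `ToyScheduler` class and its pre-scripted state list are illustrative assumptions, not ColossalAI API):

# Toy sketch of the state-driven action-list pattern used above.
from enum import Enum
from functools import partial


class State(Enum):
    PREFILL = 0
    GENERATE = 1
    COOLDOWN = 2
    DONE = 3


class ToyScheduler:
    def __init__(self, states):
        self._states = list(states)  # pre-scripted states for illustration only
        self._idx = 0

    @property
    def cur_state(self):
        return self._states[self._idx]

    def _load_stage_action(self, model):
        print(f"[{model}] prefill: load micro batch, init infer_state")

    def _gen_token_action(self, model):
        print(f"[{model}] generate: forward hidden_states, step with the new token")

    def _head_encoding_action(self, model):
        print(f"[{model}] head encoding: run the leading blocks with the new token")

    def _gen_action(self, model):
        # Mirrors _gen_one_stage_action: the current state decides which callables to queue.
        actions = []
        if self.cur_state is State.PREFILL:
            actions.append(partial(self._load_stage_action, model))
        elif self.cur_state is State.GENERATE:
            actions.append(partial(self._gen_token_action, model))
            actions.append(partial(self._head_encoding_action, model))
        elif self.cur_state is State.COOLDOWN:
            actions.append(partial(self._gen_token_action, model))
        return actions

    def run(self, model):
        # Mirrors the single-stage driver loop: build actions, run them in order, advance.
        while self.cur_state is not State.DONE:
            for action in self._gen_action(model):
                action()
            self._idx += 1


ToyScheduler([State.PREFILL, State.GENERATE, State.COOLDOWN, State.DONE]).run("toy-model")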
@@ -319,7 +378,7 @@ class GenerateSchedule(PipelineSchedule):
                 torch.cuda.synchronize()
                 self.timestamps[self.mb_manager.idx].append(time.time())
             self.mb_manager.add_descrption(inputs_dict)
-            interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+            interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
             output_dict = model_forward(model, inputs_dict, interval_inputs)
         # In GENERATE phase
         else:
@@ -330,18 +389,23 @@ class GenerateSchedule(PipelineSchedule):
                 assert (
                     hidden_states is not None
                 ), "When first stage in GENERATE phase, the hidden states should not be None"
-                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+                interval_inputs = {
+                    "hidden_states": hidden_states["hidden_states"],
+                    "infer_state": self.mb_manager.cur_infer_state,
+                }
                 logits = model_forward(model, None, interval_inputs)
                 if self.verbose and self.stage_manager.is_first_stage():
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
-                assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
-                new_token = self._get_token_id(logits['logits'])
+                assert (
+                    "logits" in logits
+                ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
+                new_token = self._get_token_id(logits["logits"])
                 self.mb_manager.step(new_token)
                 # If the current micro batch is not DONE, go through blocks
                 if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                     inputs_dict = self._prepare_inputs_for_new_token(new_token)
-                    interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+                    interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
                     output_dict = model_forward(model, inputs_dict, interval_inputs)
             else:
                 assert hidden_states is not None, "When not first stage, the hidden states should not be None"
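In the GENERATE phase, the first stage turns the head-forward output into a concrete token via `_get_token_id(logits["logits"])`. A hedged sketch of what such a helper can look like under greedy decoding (the decoding strategy here is an assumption; the real helper may sample differently):

# Hypothetical greedy-decoding helper: pick the argmax token of the last position.
import torch


def get_token_id_greedy(logits: torch.Tensor) -> torch.Tensor:
    # logits: [batch_size, seq_len, vocab_size]
    last_step_logits = logits[:, -1, :]  # keep only the newest position
    return torch.argmax(last_step_logits, dim=-1, keepdim=True)  # [batch_size, 1]

Under that assumption, the new token would be appended to each sequence and fed back through `_prepare_inputs_for_new_token` on the next step, as the diff shows.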
@@ -350,7 +414,10 @@ class GenerateSchedule(PipelineSchedule):
                 if self.mb_manager.cur_state is Status.PREFILL:
                     inputs_dict = self.load_micro_batch()
                     self.mb_manager.add_descrption(inputs_dict)
-                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+                interval_inputs = {
+                    "hidden_states": hidden_states["hidden_states"],
+                    "infer_state": self.mb_manager.cur_infer_state,
+                }
                 output_dict = model_forward(model, inputs_dict, interval_inputs)
 
         # Current microbatch is not DONE, send hidden_state to next stage