[inference] Refactor inference architecture (#5057)

* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] update example (#5053)

* update example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Author: Xu Kai
Date: 2023-11-19 21:05:05 +08:00
Committed by: GitHub
parent bc09b95f50
commit fd6482ad8c
115 changed files with 6027 additions and 1431 deletions


@@ -7,7 +7,7 @@ import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
-from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status
+from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
@@ -93,9 +93,7 @@ class GenerateSchedule(PipelineSchedule):
Returns:
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
"""
-model_inputs = {
-    'infer_state': self.mb_manager.cur_descrption.infer_state
-}
+model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
return model_inputs
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -129,8 +127,8 @@ class GenerateSchedule(PipelineSchedule):
def _init_infer_state_action(self) -> None:
"""
-This action is only for no first stage, to load batch and init infer_state.
-1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
+This action is only for no first stage, to load batch and init infer_state.
+1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
"""
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
@@ -145,19 +143,19 @@ class GenerateSchedule(PipelineSchedule):
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
-interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
-self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _gen_token_action(self, model: Module):
"""
-This action is only for first stage
+This action is only for first stage
1.do the forward with hidden_states to generate new tokens 2.step to update
"""
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
-interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
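For readers skimming this hunk, the first-stage token step boils down to: run the model on the buffered hidden states, read "logits" from the returned dict, and turn them into the next token id. Below is a minimal standalone sketch of that step; it assumes greedy (argmax) selection, since the actual strategy behind _get_token_id is not shown here, and the tensor shapes are made up for illustration.

import torch

def pick_next_token(output_dict: dict) -> torch.Tensor:
    # The first stage expects the forward output to carry "logits",
    # mirroring the assert on `logits` that appears later in this diff.
    assert "logits" in output_dict, "first-stage output must contain `logits`"
    logits = output_dict["logits"]        # [batch, seq_len, vocab]
    last_step = logits[:, -1, :]          # logits for the newest position
    # Assumption: greedy decoding; the real _get_token_id may sample instead.
    return torch.argmax(last_step, dim=-1, keepdim=True)  # [batch, 1]

# Example with a fake logits tensor for batch=2, seq=4, vocab=10
fake_output = {"logits": torch.randn(2, 4, 10)}
print(pick_next_token(fake_output))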
@@ -178,18 +176,18 @@ class GenerateSchedule(PipelineSchedule):
new_token = self.action_interval_buffer.new_token
assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
inputs_dict = self._prepare_inputs_for_new_token(new_token)
-interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
-self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _body_encoding_action(self, model: Module):
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
-interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, None, interval_inputs)
-self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
"""
@@ -233,12 +231,73 @@ class GenerateSchedule(PipelineSchedule):
return actions
+def _gen_one_stage_action(self, model: Module):
+    """
+    In this function, it will generate a sequence action for current state, and do the action one by one.
+    Args:
+        model (Module): Model to be run.
+    Returns:
+        List[Callable]: A list of action, each action is a callable function, and it will be called in order.
+    """
+    actions = []
+    if self.mb_manager.cur_state is Status.PREFILL:
+        actions.append(partial(self._load_stage_action, model))
+    elif self.mb_manager.cur_state is Status.GENERATE:
+        actions.append(partial(self._gen_token_action, model))
+        actions.append(partial(self._head_encoding_action, model))
+    elif self.mb_manager.cur_state is Status.COOLDOWN:
+        actions.append(partial(self._gen_token_action, model))
+    return actions
def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
-if self.stage_manager.num_stages == 2:
+if self.stage_manager.num_stages == 1:
+    return self.generate_step_one_stage(model, data_iter)
+elif self.stage_manager.num_stages == 2:
return self.generate_step_p2p(model, data_iter)
else:
return self.generate_step_broadcast(model, data_iter)
+@torch.no_grad()
+def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+    """
+    Forward one step of the pipeline, when pipeline size is 1.
+    Args:
+        model (Module): Model to be run.
+        data_iter (Iterable): Data iterator.
+    Returns:
+        Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
+    """
+    output_sequence = []
+    self.load_batch(data_iter)
+    model.eval()
+    self.comm_dtype = model.dtype
+    whole_timestamp = []
+    # run by round
+    for _ in range(self.round):
+        self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
+        self.action_interval_buffer.clear()
+        while self.mb_manager.is_micro_batch_done() is False:
+            actions = self._gen_one_stage_action(model)
+            for action in actions:
+                action()
+            self.mb_manager.next()
+        # All microbatch in current round is DONE
+        output_sequence.extend(self.mb_manager.export_new_tokens())
+        self.mb_manager.clear()
+        if self.verbose:
+            whole_timestamp.extend(self.timestamps)
+    return output_sequence, whole_timestamp
@torch.no_grad()
def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
"""
@@ -319,7 +378,7 @@ class GenerateSchedule(PipelineSchedule):
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
self.mb_manager.add_descrption(inputs_dict)
-interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
# In GENERATE phase
else:
@@ -330,18 +389,23 @@ class GenerateSchedule(PipelineSchedule):
assert (
hidden_states is not None
), "When first stage in GENERATE phase, the hidden states should not be None"
-interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+interval_inputs = {
+    "hidden_states": hidden_states["hidden_states"],
+    "infer_state": self.mb_manager.cur_infer_state,
+}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
-assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
-new_token = self._get_token_id(logits['logits'])
+assert (
+    "logits" in logits
+), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
+new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(new_token)
# If the current micro batch is not DONE, go through blocks
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
inputs_dict = self._prepare_inputs_for_new_token(new_token)
-interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
else:
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
@@ -350,7 +414,10 @@ class GenerateSchedule(PipelineSchedule):
if self.mb_manager.cur_state is Status.PREFILL:
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
-interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+interval_inputs = {
+    "hidden_states": hidden_states["hidden_states"],
+    "infer_state": self.mb_manager.cur_infer_state,
+}
output_dict = model_forward(model, inputs_dict, interval_inputs)
# Current microbatch is not DONE, send hidden_state to next stage
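Across the hunks above, the forward calls exchange a small dict: non-first stages receive hidden_states from the previous stage and always attach the current infer_state, while the first stage supplies the micro-batch inputs instead. A rough sketch of that hand-off is below; dummy_model_forward is a stand-in for ColossalAI's model_forward, and the tensor shapes and infer_state placeholder are made up for illustration.

import torch

def dummy_model_forward(model, micro_batch_inputs, interval_inputs):
    # Placeholder: merges the two input dicts and returns a dict carrying
    # hidden_states for the next stage, like the calls in this diff.
    hidden = interval_inputs.get("hidden_states")
    if hidden is None:                      # first stage: start from the prompt
        hidden = torch.randn(2, 8, 16)      # made-up [batch, seq, hidden] shape
    return {"hidden_states": hidden + 1.0}  # pretend the blocks transformed it

# First (PREFILL) stage: only infer_state travels in interval_inputs.
infer_state = object()                      # stand-in for the real infer state
interval_inputs = {"infer_state": infer_state}
out = dummy_model_forward(None, {"input_ids": torch.ones(2, 8, dtype=torch.long)}, interval_inputs)

# Later stages: the received hidden_states is forwarded on, always paired
# with the micro batch's current infer_state.
interval_inputs = {"hidden_states": out["hidden_states"], "infer_state": infer_state}
out = dummy_model_forward(None, None, interval_inputs)
print(out["hidden_states"].shape)           # torch.Size([2, 8, 16])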