mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 05:49:55 +00:00
[Pipeline inference] Combine kvcache with pipeline inference (#4938)
* merge kvcache with pipeline inference and refactor the code structure
* support ppsize > 2
* refactor pipeline code
* do pre-commit
* modify benchmark
* fix benchmark
* polish code
* add docstring and update readme
* refactor the code
* fix some logic bugs in ppinfer
* polish readme
* fix typo
* skip infer test
This commit is contained in:
@@ -93,9 +93,9 @@ class GenerateSchedule(PipelineSchedule):
|
||||
Returns:
|
||||
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
|
||||
"""
|
||||
model_inputs = (
|
||||
{"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None
|
||||
)
|
||||
model_inputs = {
|
||||
'infer_state': self.mb_manager.cur_descrption.infer_state
|
||||
}
|
||||
return model_inputs
|
||||
|
||||
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
|
||||
@@ -108,9 +108,8 @@ class GenerateSchedule(PipelineSchedule):
|
||||
dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
|
||||
"""
|
||||
new_mask = self.mb_manager.cur_descrption.attn_mask
|
||||
past_key_values = self.mb_manager.cur_descrption.kv_cache
|
||||
|
||||
return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values)
|
||||
return dict(input_ids=new_token, attention_mask=new_mask)
|
||||
|
||||
def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
last_hidden_state = hidden_state[:, -1]
|
||||
@@ -128,27 +127,38 @@ class GenerateSchedule(PipelineSchedule):
|
||||
return self.comm.p2p_recv()
|
||||
return self.comm.recv_forward()
|
||||
|
||||
def _load_stage_action(self, model: Module) -> None:
|
||||
def _init_infer_state_action(self) -> None:
|
||||
"""
|
||||
In this action, 1.load micro_batch 2.do the forward 3.step to update
|
||||
This action is only for no first stage, to load batch and init infer_state.
|
||||
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
|
||||
"""
|
||||
inputs_dict = self.load_micro_batch()
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
|
||||
def _load_stage_action(self, model: Module) -> None:
|
||||
"""
|
||||
This action is only for first stage, load, init and do forward.
|
||||
1.load micro_batch 2.do the forward 3.step to update
|
||||
"""
|
||||
inputs_dict = self.load_micro_batch()
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
output_dict = model_forward(model, inputs_dict, None)
|
||||
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
|
||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
||||
|
||||
def _gen_token_action(self, model: Module):
|
||||
"""
|
||||
In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update
|
||||
This action is only for first stage
|
||||
1.do the forward with hidden_states to generate new tokens 2.step to update
|
||||
"""
|
||||
hidden_states = self.action_interval_buffer.hidden_states
|
||||
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
|
||||
hidden_states = {"hidden_states": hidden_states}
|
||||
logits = model_forward(model, None, hidden_states)
|
||||
interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
|
||||
logits = model_forward(model, None, interval_inputs)
|
||||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
@@ -157,7 +167,7 @@ class GenerateSchedule(PipelineSchedule):
|
||||
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||
new_token = self._get_token_id(logits["logits"])
|
||||
|
||||
self.mb_manager.step(None, None, new_token)
|
||||
self.mb_manager.step(new_token)
|
||||
self.action_interval_buffer.new_token = new_token
|
||||
self.action_interval_buffer.hidden_states = None
|
||||
|
||||
@@ -168,20 +178,18 @@ class GenerateSchedule(PipelineSchedule):
|
||||
new_token = self.action_interval_buffer.new_token
|
||||
assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
|
||||
inputs_dict = self._prepare_inputs_for_new_token(new_token)
|
||||
output_dict = model_forward(model, inputs_dict, None)
|
||||
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
|
||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
||||
|
||||
def _body_encoding_action(self, model: Module):
|
||||
hidden_states = self.action_interval_buffer.hidden_states
|
||||
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
|
||||
inputs_dict = self._prepare_inputs_for_interval_stage()
|
||||
hidden_states = {"hidden_states": hidden_states}
|
||||
output_dict = model_forward(model, inputs_dict, hidden_states)
|
||||
interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, None, interval_inputs)
|
||||
|
||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
||||
|
||||
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
|
||||
"""
|
||||
@@ -218,6 +226,8 @@ class GenerateSchedule(PipelineSchedule):
|
||||
actions.append(partial(self._gen_token_action, model))
|
||||
# other stage
|
||||
else:
|
||||
if self.mb_manager.cur_state is Status.PREFILL:
|
||||
actions.append(partial(self._init_infer_state_action))
|
||||
actions.append(partial(self._comm_action, True))
|
||||
actions.append(partial(self._body_encoding_action, model))
|
||||
|
||||
@@ -308,8 +318,9 @@ class GenerateSchedule(PipelineSchedule):
|
||||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
output_dict = model_forward(model, inputs_dict, None)
|
||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
# In GENERATE phase
|
||||
else:
|
||||
# Get hidden_states from previous stage
|
||||
@@ -319,25 +330,28 @@ class GenerateSchedule(PipelineSchedule):
|
||||
assert (
|
||||
hidden_states is not None
|
||||
), "When first stage in GENERATE phase, the hidden states should not be None"
|
||||
logits = model_forward(model, None, hidden_states)
|
||||
interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
|
||||
logits = model_forward(model, None, interval_inputs)
|
||||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
assert (
|
||||
"logits" in logits
|
||||
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||
new_token = self._get_token_id(logits["logits"])
|
||||
self.mb_manager.step(None, None, new_token)
|
||||
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||
new_token = self._get_token_id(logits['logits'])
|
||||
self.mb_manager.step(new_token)
|
||||
# If the current micro batch is not DONE, go through blocks
|
||||
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
|
||||
inputs_dict = self._prepare_inputs_for_new_token(new_token)
|
||||
output_dict = model_forward(model, inputs_dict, None)
|
||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
else:
|
||||
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
|
||||
inputs_dict = self._prepare_inputs_for_interval_stage()
|
||||
output_dict = model_forward(model, inputs_dict, hidden_states)
|
||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||
# inputs_dict = self._prepare_inputs_for_interval_stage()
|
||||
inputs_dict = None
|
||||
if self.mb_manager.cur_state is Status.PREFILL:
|
||||
inputs_dict = self.load_micro_batch()
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
|
||||
# Current microbatch is not DONE, send hidden_state to next stage
|
||||
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
|
||||
|
Reference in New Issue
Block a user