[Pipeline inference] Combine kvcache with pipeline inference (#4938)

* merge kvcache with pipeline inference and refactor the code structure

* support ppsize > 2

* refactor pipeline code

* run pre-commit

* modify benchmark

* fix benchmark

* polish code

* add docstring and update readme

* refactor the code

* fix some logic bugs in ppinfer

* polish readme

* fix typo

* skip infer test
Author: Bin Jia
Date: 2023-10-27 16:19:54 +08:00
Committed by: GitHub
Parent: c6cd629e7a
Commit: 1db6727678
19 changed files with 922 additions and 745 deletions


@@ -93,9 +93,9 @@ class GenerateSchedule(PipelineSchedule):
         Returns:
             dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
         """
-        model_inputs = (
-            {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None
-        )
+        model_inputs = {
+            'infer_state': self.mb_manager.cur_descrption.infer_state
+        }
         return model_inputs

     def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
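Note: the hunk above stops threading raw `past_key_values` between calls and instead reads a persistent `infer_state` off the micro-batch description. A minimal sketch of what such a description might hold; the `Descrption`/`infer_state` names follow the diff's spelling, while the fields are assumptions, not ColossalAI's actual classes:

from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class InferState:
    # The KV cache lives with the request for its whole lifetime, so
    # interval stages no longer pass `past_key_values` on every forward.
    cache_k: Optional[torch.Tensor] = None
    cache_v: Optional[torch.Tensor] = None
    seq_len: int = 0


@dataclass
class Descrption:  # spelling follows the diff
    attn_mask: torch.Tensor
    infer_state: InferState = field(default_factory=InferState)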
@@ -108,9 +108,8 @@ class GenerateSchedule(PipelineSchedule):
             dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
         """
         new_mask = self.mb_manager.cur_descrption.attn_mask
-        past_key_values = self.mb_manager.cur_descrption.kv_cache
-        return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values)
+        return dict(input_ids=new_token, attention_mask=new_mask)

     def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
         last_hidden_state = hidden_state[:, -1]
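Note: `_get_token_id` (its first line is visible above) reduces the model output at the last position to the next token id. A greedy-decoding sketch of that step, assuming argmax selection over the vocabulary:

import torch

def get_token_id(logits: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq_len, vocab_size); keep only the final position
    last_logits = logits[:, -1]
    # Greedy choice, returned as (batch, 1) so it can feed the next forward
    return torch.argmax(last_logits, dim=-1, keepdim=True)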
@@ -128,27 +127,38 @@ class GenerateSchedule(PipelineSchedule):
             return self.comm.p2p_recv()
         return self.comm.recv_forward()

-    def _load_stage_action(self, model: Module) -> None:
+    def _init_infer_state_action(self) -> None:
         """
-        In this action, 1.load micro_batch 2.do the forward 3.step to update
+        This action is only for no first stage, to load batch and init infer_state.
+        1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
         """
         inputs_dict = self.load_micro_batch()
+        self.mb_manager.add_descrption(inputs_dict)
+
+    def _load_stage_action(self, model: Module) -> None:
+        """
+        This action is only for first stage, load, init and do forward.
+        1.load micro_batch 2.do the forward 3.step to update
+        """
+        inputs_dict = self.load_micro_batch()
+        self.mb_manager.add_descrption(inputs_dict)
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
-        output_dict = model_forward(model, inputs_dict, None)
+        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        output_dict = model_forward(model, inputs_dict, interval_inputs)
-        self.mb_manager.step(inputs_dict, output_dict, None)
-        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
+        self.action_interval_buffer.hidden_states = output_dict['hidden_states']

     def _gen_token_action(self, model: Module):
         """
-        In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update
+        This action is only for first stage
+        1.do the forward with hidden_states to generate new tokens 2.step to update
         """
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
-        hidden_states = {"hidden_states": hidden_states}
-        logits = model_forward(model, None, hidden_states)
+        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        logits = model_forward(model, None, interval_inputs)
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
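Note: throughout this hunk, the second argument to `model_forward` carries the micro-batch inputs and the third carries the inter-stage (`interval`) inputs. A plausible reading of that convention, assuming the helper simply merges both dicts into keyword arguments; illustrative, not ColossalAI's actual implementation:

from typing import Any, Dict, Optional

from torch.nn import Module

def model_forward(model: Module, data: Optional[Dict[str, Any]], internal_inputs: Optional[Dict[str, Any]]) -> Any:
    # Merge per-micro-batch inputs (input_ids, attention_mask, ...) with
    # pipeline-internal inputs (hidden_states, infer_state, ...)
    kwargs = {**(data or {}), **(internal_inputs or {})}
    return model(**kwargs)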
@@ -157,7 +167,7 @@ class GenerateSchedule(PipelineSchedule):
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(None, None, new_token)
self.mb_manager.step(new_token)
self.action_interval_buffer.new_token = new_token
self.action_interval_buffer.hidden_states = None
@@ -168,20 +178,18 @@ class GenerateSchedule(PipelineSchedule):
         new_token = self.action_interval_buffer.new_token
         assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
         inputs_dict = self._prepare_inputs_for_new_token(new_token)
-        output_dict = model_forward(model, inputs_dict, None)
+        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        output_dict = model_forward(model, inputs_dict, interval_inputs)
-        self.mb_manager.step(inputs_dict, output_dict, None)
-        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
+        self.action_interval_buffer.hidden_states = output_dict['hidden_states']

     def _body_encoding_action(self, model: Module):
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When not first stage, the hidden states should not be None"
-        inputs_dict = self._prepare_inputs_for_interval_stage()
-        hidden_states = {"hidden_states": hidden_states}
-        output_dict = model_forward(model, inputs_dict, hidden_states)
+        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        output_dict = model_forward(model, None, interval_inputs)
-        self.mb_manager.step(inputs_dict, output_dict, None)
-        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
+        self.action_interval_buffer.hidden_states = output_dict['hidden_states']

     def _comm_action(self, recv_pre: bool) -> torch.Tensor:
         """
@@ -218,6 +226,8 @@ class GenerateSchedule(PipelineSchedule):
             actions.append(partial(self._gen_token_action, model))
         # other stage
         else:
+            if self.mb_manager.cur_state is Status.PREFILL:
+                actions.append(partial(self._init_infer_state_action))
             actions.append(partial(self._comm_action, True))
             actions.append(partial(self._body_encoding_action, model))
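Note: the new `Status.PREFILL` check lets a non-first stage lazily initialize its `infer_state` the first time a micro batch arrives. A hedged sketch of the status values the scheduler branches on; `PREFILL`, `GENERATE`, `COOLDOWN`, and `DONE` all appear in this diff, but the enum layout itself is an assumption:

from enum import Enum, auto

class Status(Enum):
    PREFILL = auto()   # first pass over the prompt; infer_state just created
    GENERATE = auto()  # steady-state decoding, one new token per round
    COOLDOWN = auto()  # final rounds draining the pipeline
    DONE = auto()      # micro batch finished; skip further blocks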
@@ -308,8 +318,9 @@ class GenerateSchedule(PipelineSchedule):
                 if self.verbose and self.stage_manager.is_first_stage():
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
-                output_dict = model_forward(model, inputs_dict, None)
-                self.mb_manager.step(inputs_dict, output_dict, None)
+                self.mb_manager.add_descrption(inputs_dict)
+                interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+                output_dict = model_forward(model, inputs_dict, interval_inputs)
             # In GENERATE phase
             else:
                 # Get hidden_states from previous stage
@@ -319,25 +330,28 @@ class GenerateSchedule(PipelineSchedule):
                 assert (
                     hidden_states is not None
                 ), "When first stage in GENERATE phase, the hidden states should not be None"
-                logits = model_forward(model, None, hidden_states)
+                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+                logits = model_forward(model, None, interval_inputs)
                 if self.verbose and self.stage_manager.is_first_stage():
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
-                assert (
-                    "logits" in logits
-                ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
-                new_token = self._get_token_id(logits["logits"])
-                self.mb_manager.step(None, None, new_token)
+                assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
+                new_token = self._get_token_id(logits['logits'])
+                self.mb_manager.step(new_token)
                 # If the current micro batch is not DONE, go through blocks
                 if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                     inputs_dict = self._prepare_inputs_for_new_token(new_token)
-                    output_dict = model_forward(model, inputs_dict, None)
-                    self.mb_manager.step(inputs_dict, output_dict, None)
+                    interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+                    output_dict = model_forward(model, inputs_dict, interval_inputs)
         else:
             assert hidden_states is not None, "When not first stage, the hidden states should not be None"
-            inputs_dict = self._prepare_inputs_for_interval_stage()
-            output_dict = model_forward(model, inputs_dict, hidden_states)
-            self.mb_manager.step(inputs_dict, output_dict, None)
+            # inputs_dict = self._prepare_inputs_for_interval_stage()
+            inputs_dict = None
+            if self.mb_manager.cur_state is Status.PREFILL:
+                inputs_dict = self.load_micro_batch()
+                self.mb_manager.add_descrption(inputs_dict)
+            interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+            output_dict = model_forward(model, inputs_dict, interval_inputs)

         # Current microbatch is not DONE, send hidden_state to next stage
         if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
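Note: the hunk above is truncated in this view. Putting the refactor together, a highly simplified, hypothetical round for a non-first stage; names follow the diff, `Status` and `model_forward` refer to the sketches above, and `send_forward` is assumed as the counterpart of `recv_forward`:

def non_first_stage_round(model, comm, mb_manager, load_micro_batch):
    # One scheduling round on a stage > 0, per the refactor above
    hidden_states = comm.recv_forward()       # from the previous stage
    inputs_dict = None
    if mb_manager.cur_state is Status.PREFILL:
        # First visit to this micro batch: load it and build its infer_state
        inputs_dict = load_micro_batch()
        mb_manager.add_descrption(inputs_dict)
    interval_inputs = {
        'hidden_states': hidden_states['hidden_states'],
        'infer_state': mb_manager.cur_infer_state,  # KV cache rides along here
    }
    output_dict = model_forward(model, inputs_dict, interval_inputs)
    comm.send_forward(output_dict)            # hand off to the next stage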