[inference] support only TP (#4998)

* support only tp

* enable tp
Author: Xu Kai, 2023-11-01 16:33:30 +08:00; committed by FoolPlayer
parent f71e63b0f3
commit f747d13040
4 changed files with 120 additions and 41 deletions
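What this unlocks in practice: with the `pp_size > 1` assertion removed, CaiInferEngine can be built for tensor-parallel-only inference. The sketch below is illustrative and not part of the commit; the import path is assumed, `model` and `model_policy` are placeholders for a causal-LM plus its matching ColossalAI inference policy built elsewhere, and the process group is assumed to be initialized beforehand with colossalai.launch, as in the test file at the end of this diff.

# Illustrative sketch, not from this commit. Assumes colossalai.launch(...) has already
# initialized the process group, and that `model` / `model_policy` exist (placeholders).
from colossalai.inference.pipeline import CaiInferEngine  # import path assumed

engine = CaiInferEngine(
    tp_size=2,          # tensor-parallel degree
    pp_size=1,          # no pipeline parallelism: the case this commit enables
    model=model,
    model_policy=model_policy,
    max_batch_size=4,   # the engine asserts max_batch_size <= 64
    max_input_len=128,  # max_input_len + max_output_len must stay <= 4096
    max_output_len=32,
    micro_batch_size=1,
)
# Generation is then driven through the engine's usual entry point (not shown in this diff).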

View File

@@ -85,8 +85,6 @@ class CaiInferEngine:
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
- # TODO: support only tensor parallel inference
- assert pp_size > 1, "Not support only tensor parallel inference."
self.pp_size = pp_size
self.tp_size = tp_size
@@ -102,23 +100,21 @@ class CaiInferEngine:
# Init pg mesh
pg_mesh = ProcessGroupMesh(pp_size, tp_size)
- stage_manager = None
- if pp_size > 1:
- stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
- self.cache_manager_list = [
- self._init_manager(model, max_batch_size, max_input_len, max_output_len)
- for _ in range(micro_batch_buffer_size or pp_size)
- ]
- self.mb_manager = MicroBatchManager(
- stage_manager.stage,
- micro_batch_size,
- micro_batch_buffer_size or pp_size,
- max_input_len,
- max_output_len,
- self.cache_manager_list,
- )
- self.verbose = verbose
- self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
+ stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
+ self.cache_manager_list = [
+ self._init_manager(model, max_batch_size, max_input_len, max_output_len)
+ for _ in range(micro_batch_buffer_size or pp_size)
+ ]
+ self.mb_manager = MicroBatchManager(
+ stage_manager.stage,
+ micro_batch_size,
+ micro_batch_buffer_size or pp_size,
+ max_input_len,
+ max_output_len,
+ self.cache_manager_list,
+ )
+ self.verbose = verbose
+ self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
@@ -146,7 +142,7 @@ class CaiInferEngine:
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
- enable_tensor_parallelism=False,
+ enable_tensor_parallelism=True if self.tp_size > 1 else False,
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,

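A side note on the ShardConfig hunk above: `True if self.tp_size > 1 else False` has the same truth value as the bare comparison `self.tp_size > 1`. Below is a minimal sketch of the equivalent wiring, assuming `ShardConfig` is importable from colossalai.shardformer and that `tp_group`, `stage_manager`, and `tp_size` are the objects built in the engine code above; it is not part of the commit.

# Sketch only: the flag just tracks whether a tensor-parallel group of size > 1 exists.
# `tp_group`, `stage_manager`, `tp_size` are assumed to exist.
from colossalai.shardformer import ShardConfig  # import location assumed

shardconfig = ShardConfig(
    tensor_parallel_process_group=tp_group,
    pipeline_stage_manager=stage_manager,
    enable_tensor_parallelism=tp_size > 1,  # same value as the ternary in the diff
    enable_fused_normalization=False,
    enable_all_optimization=False,
    enable_flash_attention=False,
)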
View File

@@ -103,4 +103,4 @@ class MemoryManager:
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
self.max_len_in_batch = 0
- self.logger.info("freed all space of memory manager")
+ # self.logger.info("freed all space of memory manager")

View File

@@ -93,9 +93,7 @@ class GenerateSchedule(PipelineSchedule):
Returns:
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
"""
- model_inputs = {
- 'infer_state': self.mb_manager.cur_descrption.infer_state
- }
+ model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
return model_inputs
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -129,8 +127,8 @@ class GenerateSchedule(PipelineSchedule):
def _init_infer_state_action(self) -> None:
"""
This action is only for no first stage, to load batch and init infer_state.
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
"""
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
@@ -145,19 +143,19 @@ class GenerateSchedule(PipelineSchedule):
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
- interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+ interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
- self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+ self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _gen_token_action(self, model: Module):
"""
This action is only for first stage
1.do the forward with hidden_states to generate new tokens 2.step to update
"""
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
- interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+ interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
@@ -178,18 +176,18 @@ class GenerateSchedule(PipelineSchedule):
new_token = self.action_interval_buffer.new_token
assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
inputs_dict = self._prepare_inputs_for_new_token(new_token)
- interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+ interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
- self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+ self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _body_encoding_action(self, model: Module):
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
- interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+ interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, None, interval_inputs)
- self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+ self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
"""
@@ -233,12 +231,73 @@ class GenerateSchedule(PipelineSchedule):
return actions
+ def _gen_one_stage_action(self, model: Module):
+ """
+ In this function, it will generate a sequence action for current state, and do the action one by one.
+ Args:
+ model (Module): Model to be run.
+ Returns:
+ List[Callable]: A list of action, each action is a callable function, and it will be called in order.
+ """
+ actions = []
+ if self.mb_manager.cur_state is Status.PREFILL:
+ actions.append(partial(self._load_stage_action, model))
+ elif self.mb_manager.cur_state is Status.GENERATE:
+ actions.append(partial(self._gen_token_action, model))
+ actions.append(partial(self._head_encoding_action, model))
+ elif self.mb_manager.cur_state is Status.COOLDOWN:
+ actions.append(partial(self._gen_token_action, model))
+ return actions
def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
- if self.stage_manager.num_stages == 2:
+ if self.stage_manager.num_stages == 1:
+ return self.generate_step_one_stage(model, data_iter)
+ elif self.stage_manager.num_stages == 2:
return self.generate_step_p2p(model, data_iter)
else:
return self.generate_step_broadcast(model, data_iter)
+ @torch.no_grad()
+ def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+ """
+ Forward one step of the pipeline, when pipeline size is 1.
+ Args:
+ model (Module): Model to be run.
+ data_iter (Iterable): Data iterator.
+ Returns:
+ Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
+ """
+ output_sequence = []
+ self.load_batch(data_iter)
+ model.eval()
+ self.comm_dtype = model.dtype
+ whole_timestamp = []
+ # run by round
+ for _ in range(self.round):
+ self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
+ self.action_interval_buffer.clear()
+ while self.mb_manager.is_micro_batch_done() is False:
+ actions = self._gen_one_stage_action(model)
+ for action in actions:
+ action()
+ self.mb_manager.next()
+ # All microbatch in current round is DONE
+ output_sequence.extend(self.mb_manager.export_new_tokens())
+ self.mb_manager.clear()
+ if self.verbose:
+ whole_timestamp.extend(self.timestamps)
+ return output_sequence, whole_timestamp
@torch.no_grad()
def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
"""
@@ -319,7 +378,7 @@ class GenerateSchedule(PipelineSchedule):
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
self.mb_manager.add_descrption(inputs_dict)
- interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+ interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
# In GENERATE phase
else:
@@ -330,18 +389,23 @@ class GenerateSchedule(PipelineSchedule):
assert (
hidden_states is not None
), "When first stage in GENERATE phase, the hidden states should not be None"
- interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+ interval_inputs = {
+ "hidden_states": hidden_states["hidden_states"],
+ "infer_state": self.mb_manager.cur_infer_state,
+ }
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
- assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
- new_token = self._get_token_id(logits['logits'])
+ assert (
+ "logits" in logits
+ ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
+ new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(new_token)
# If the current micro batch is not DONE, go through blocks
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
inputs_dict = self._prepare_inputs_for_new_token(new_token)
- interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+ interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
else:
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
@@ -350,7 +414,10 @@ class GenerateSchedule(PipelineSchedule):
if self.mb_manager.cur_state is Status.PREFILL:
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
- interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+ interval_inputs = {
+ "hidden_states": hidden_states["hidden_states"],
+ "infer_state": self.mb_manager.cur_infer_state,
+ }
output_dict = model_forward(model, inputs_dict, interval_inputs)
# Current microbatch is not DONE, send hidden_state to next stage

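To make the new control flow concrete: with a single pipeline stage, generate_step now dispatches to generate_step_one_stage, which repeatedly applies _gen_one_stage_action to each micro batch of a round (PREFILL first, then GENERATE/COOLDOWN) and collects the generated tokens. A hedged sketch of driving it, assuming `schedule` is the GenerateSchedule owned by CaiInferEngine, `model` the sharded model, and `dataloader` an iterable of tokenized batches:

# Sketch only, not from this commit; `schedule`, `model`, `dataloader` are placeholders.
output_tokens, timestamps = schedule.generate_step(model, iter(dataloader))
# With stage_manager.num_stages == 1 this routes to generate_step_one_stage, which
# runs the actions from _gen_one_stage_action until every micro batch in the round
# is done, then exports the new tokens via mb_manager.export_new_tokens().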
View File

@@ -65,6 +65,16 @@ def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
def check_pipeline_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_pipeline_inference_test()
@@ -75,6 +85,11 @@ def check_tp_pipeline_inference(rank, world_size, port):
run_tp_pipeline_inference_test()
+ def check_tp_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_tp_inference_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@@ -82,6 +97,7 @@ def check_tp_pipeline_inference(rank, world_size, port):
def test_pipeline_inference():
spawn(check_pipeline_inference, nprocs=2)
spawn(check_tp_pipeline_inference, nprocs=4)
+ spawn(check_tp_inference, nprocs=2)
if __name__ == "__main__":