[inference] support only TP (#4998)

* support only tp

* enable tp
This commit is contained in:
Xu Kai 2023-11-01 16:33:30 +08:00 committed by FoolPlayer
parent f71e63b0f3
commit f747d13040
4 changed files with 120 additions and 41 deletions
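Net effect of the commit, condensed from the diff below: the engine no longer rejects `pp_size == 1` (the `assert pp_size > 1` guard is removed), the `PipelineStageManager`, KV-cache managers, micro-batch manager and `GenerateSchedule` are built unconditionally, and `ShardConfig` enables tensor parallelism whenever `tp_size > 1`. The following is a minimal self-contained sketch of those configuration rules; `ParallelPlan` and `plan_parallel` are hypothetical names written for this note, not classes or functions from the repository.

# Hypothetical sketch of the parallel-configuration rules this commit introduces.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ParallelPlan:
    pp_size: int
    tp_size: int
    enable_tensor_parallelism: bool
    num_cache_managers: int


def plan_parallel(pp_size: int, tp_size: int, micro_batch_buffer_size: Optional[int] = None) -> ParallelPlan:
    # Before this commit the engine asserted `pp_size > 1`; now pp_size == 1 is accepted,
    # so a TP-only (or single-stage) configuration is valid.
    assert pp_size >= 1 and tp_size >= 1, "parallel sizes must be positive"
    return ParallelPlan(
        pp_size=pp_size,
        tp_size=tp_size,
        # Mirrors `enable_tensor_parallelism=True if self.tp_size > 1 else False` in ShardConfig.
        enable_tensor_parallelism=tp_size > 1,
        # Mirrors `micro_batch_buffer_size or pp_size`, the number of KV-cache managers created.
        num_cache_managers=micro_batch_buffer_size or pp_size,
    )


if __name__ == "__main__":
    # A TP-only setup that the removed assertion used to reject.
    print(plan_parallel(pp_size=1, tp_size=2))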

View File

@@ -85,8 +85,6 @@ class CaiInferEngine:
         assert max_batch_size <= 64, "Max batch size exceeds the constraint"
         assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
-        # TODO: support only tensor parallel inference
-        assert pp_size > 1, "Not support only tensor parallel inference."
         self.pp_size = pp_size
         self.tp_size = tp_size
@@ -102,23 +100,21 @@ class CaiInferEngine:
         # Init pg mesh
         pg_mesh = ProcessGroupMesh(pp_size, tp_size)
-        stage_manager = None
-        if pp_size > 1:
-            stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
-            self.cache_manager_list = [
-                self._init_manager(model, max_batch_size, max_input_len, max_output_len)
-                for _ in range(micro_batch_buffer_size or pp_size)
-            ]
-            self.mb_manager = MicroBatchManager(
-                stage_manager.stage,
-                micro_batch_size,
-                micro_batch_buffer_size or pp_size,
-                max_input_len,
-                max_output_len,
-                self.cache_manager_list,
-            )
-            self.verbose = verbose
-            self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
+        stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
+        self.cache_manager_list = [
+            self._init_manager(model, max_batch_size, max_input_len, max_output_len)
+            for _ in range(micro_batch_buffer_size or pp_size)
+        ]
+        self.mb_manager = MicroBatchManager(
+            stage_manager.stage,
+            micro_batch_size,
+            micro_batch_buffer_size or pp_size,
+            max_input_len,
+            max_output_len,
+            self.cache_manager_list,
+        )
+        self.verbose = verbose
+        self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
         self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
@@ -146,7 +142,7 @@ class CaiInferEngine:
         shardconfig = ShardConfig(
             tensor_parallel_process_group=tp_group,
             pipeline_stage_manager=stage_manager,
-            enable_tensor_parallelism=False,
+            enable_tensor_parallelism=True if self.tp_size > 1 else False,
             enable_fused_normalization=False,
             enable_all_optimization=False,
             enable_flash_attention=False,

View File

@@ -103,4 +103,4 @@ class MemoryManager:
         self.available_size = len(self.mem_state)
         self.mem_state[:] = 1
         self.max_len_in_batch = 0
-        self.logger.info("freed all space of memory manager")
+        # self.logger.info("freed all space of memory manager")

View File

@@ -93,9 +93,7 @@ class GenerateSchedule(PipelineSchedule):
         Returns:
             dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
         """
-        model_inputs = {
-            'infer_state': self.mb_manager.cur_descrption.infer_state
-        }
+        model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
         return model_inputs
 
     def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -129,8 +127,8 @@ class GenerateSchedule(PipelineSchedule):
     def _init_infer_state_action(self) -> None:
         """
         This action is only for no first stage, to load batch and init infer_state.
         1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
         """
         inputs_dict = self.load_micro_batch()
         self.mb_manager.add_descrption(inputs_dict)
@@ -145,19 +143,19 @@ class GenerateSchedule(PipelineSchedule):
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
-        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, inputs_dict, interval_inputs)
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _gen_token_action(self, model: Module):
         """
         This action is only for first stage
         1.do the forward with hidden_states to generate new tokens 2.step to update
         """
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
-        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
         logits = model_forward(model, None, interval_inputs)
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
@@ -178,18 +176,18 @@ class GenerateSchedule(PipelineSchedule):
         new_token = self.action_interval_buffer.new_token
         assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
         inputs_dict = self._prepare_inputs_for_new_token(new_token)
-        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, inputs_dict, interval_inputs)
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _body_encoding_action(self, model: Module):
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When not first stage, the hidden states should not be None"
-        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
+        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
         output_dict = model_forward(model, None, interval_inputs)
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
 
     def _comm_action(self, recv_pre: bool) -> torch.Tensor:
         """
@@ -233,12 +231,73 @@ class GenerateSchedule(PipelineSchedule):
         return actions
 
+    def _gen_one_stage_action(self, model: Module):
+        """
+        In this function, it will generate a sequence action for current state, and do the action one by one.
+
+        Args:
+            model (Module): Model to be run.
+        Returns:
+            List[Callable]: A list of action, each action is a callable function, and it will be called in order.
+        """
+        actions = []
+        if self.mb_manager.cur_state is Status.PREFILL:
+            actions.append(partial(self._load_stage_action, model))
+        elif self.mb_manager.cur_state is Status.GENERATE:
+            actions.append(partial(self._gen_token_action, model))
+            actions.append(partial(self._head_encoding_action, model))
+        elif self.mb_manager.cur_state is Status.COOLDOWN:
+            actions.append(partial(self._gen_token_action, model))
+
+        return actions
+
     def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
-        if self.stage_manager.num_stages == 2:
+        if self.stage_manager.num_stages == 1:
+            return self.generate_step_one_stage(model, data_iter)
+        elif self.stage_manager.num_stages == 2:
             return self.generate_step_p2p(model, data_iter)
         else:
             return self.generate_step_broadcast(model, data_iter)
 
+    @torch.no_grad()
+    def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+        """
+        Forward one step of the pipeline, when pipeline size is 1.
+
+        Args:
+            model (Module): Model to be run.
+            data_iter (Iterable): Data iterator.
+
+        Returns:
+            Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
+        """
+        output_sequence = []
+        self.load_batch(data_iter)
+        model.eval()
+        self.comm_dtype = model.dtype
+
+        whole_timestamp = []
+
+        # run by round
+        for _ in range(self.round):
+            self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
+            self.action_interval_buffer.clear()
+            while self.mb_manager.is_micro_batch_done() is False:
+                actions = self._gen_one_stage_action(model)
+                for action in actions:
+                    action()
+                self.mb_manager.next()
+            # All microbatch in current round is DONE
+            output_sequence.extend(self.mb_manager.export_new_tokens())
+            self.mb_manager.clear()
+            if self.verbose:
+                whole_timestamp.extend(self.timestamps)
+
+        return output_sequence, whole_timestamp
+
     @torch.no_grad()
     def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
         """
@@ -319,7 +378,7 @@ class GenerateSchedule(PipelineSchedule):
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
                 self.mb_manager.add_descrption(inputs_dict)
-                interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+                interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
                 output_dict = model_forward(model, inputs_dict, interval_inputs)
             # In GENERATE phase
             else:
@@ -330,18 +389,23 @@ class GenerateSchedule(PipelineSchedule):
                 assert (
                     hidden_states is not None
                 ), "When first stage in GENERATE phase, the hidden states should not be None"
-                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+                interval_inputs = {
+                    "hidden_states": hidden_states["hidden_states"],
+                    "infer_state": self.mb_manager.cur_infer_state,
+                }
                 logits = model_forward(model, None, interval_inputs)
                 if self.verbose and self.stage_manager.is_first_stage():
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
-                assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
-                new_token = self._get_token_id(logits['logits'])
+                assert (
+                    "logits" in logits
+                ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
+                new_token = self._get_token_id(logits["logits"])
                 self.mb_manager.step(new_token)
                 # If the current micro batch is not DONE, go through blocks
                 if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                     inputs_dict = self._prepare_inputs_for_new_token(new_token)
-                    interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+                    interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
                     output_dict = model_forward(model, inputs_dict, interval_inputs)
         else:
             assert hidden_states is not None, "When not first stage, the hidden states should not be None"
@@ -350,7 +414,10 @@ class GenerateSchedule(PipelineSchedule):
             if self.mb_manager.cur_state is Status.PREFILL:
                 inputs_dict = self.load_micro_batch()
                 self.mb_manager.add_descrption(inputs_dict)
-            interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+            interval_inputs = {
+                "hidden_states": hidden_states["hidden_states"],
+                "infer_state": self.mb_manager.cur_infer_state,
+            }
             output_dict = model_forward(model, inputs_dict, interval_inputs)
 
             # Current microbatch is not DONE, send hidden_state to next stage
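The single-stage path added above reuses the schedule's action-list pattern: for each micro-batch state it builds a list of `functools.partial` callables and executes them in order until the micro-batch is done. Below is a self-contained sketch of that pattern; `Status`, the action functions, and the driver loop here are simplified stand-ins written for this note, not the repository's classes.

# Sketch of the action-list pattern used by `_gen_one_stage_action` and
# `generate_step_one_stage`; all names below are illustrative stand-ins.
from enum import Enum, auto
from functools import partial
from typing import Callable, List


class Status(Enum):
    PREFILL = auto()
    GENERATE = auto()
    COOLDOWN = auto()


def load_stage_action(model: str) -> None:
    print(f"[{model}] prefill: load micro-batch and run the first forward pass")


def gen_token_action(model: str) -> None:
    print(f"[{model}] sample a new token from the last hidden states")


def head_encoding_action(model: str) -> None:
    print(f"[{model}] feed the new token back through the blocks")


def plan_actions(state: Status, model: str) -> List[Callable[[], None]]:
    # Mirrors the branching in `_gen_one_stage_action`: one list of callables per state,
    # executed in order by the single-stage generate step.
    actions: List[Callable[[], None]] = []
    if state is Status.PREFILL:
        actions.append(partial(load_stage_action, model))
    elif state is Status.GENERATE:
        actions.append(partial(gen_token_action, model))
        actions.append(partial(head_encoding_action, model))
    elif state is Status.COOLDOWN:
        actions.append(partial(gen_token_action, model))
    return actions


if __name__ == "__main__":
    for state in (Status.PREFILL, Status.GENERATE, Status.COOLDOWN):
        for action in plan_actions(state, model="toy-llm"):
            action()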

View File

@@ -65,6 +65,16 @@ def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch
     torch.cuda.empty_cache()
 
 
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [1])
+@parameterize("max_output_len", [2])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+    torch.cuda.empty_cache()
+
+
 def check_pipeline_inference(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_pipeline_inference_test()
@@ -75,6 +85,11 @@ def check_tp_pipeline_inference(rank, world_size, port):
     run_tp_pipeline_inference_test()
 
 
+def check_tp_inference(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_tp_inference_test()
+
+
 @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
@@ -82,6 +97,7 @@ def check_tp_pipeline_inference(rank, world_size, port):
 def test_pipeline_inference():
     spawn(check_pipeline_inference, nprocs=2)
     spawn(check_tp_pipeline_inference, nprocs=4)
+    spawn(check_tp_inference, nprocs=2)
 
 
 if __name__ == "__main__":
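The new test exercises the TP-only path with `tp_size=2, pp_size=1`, launched on two ranks via `spawn(check_tp_inference, nprocs=2)`. A tiny illustrative helper (not part of the test file) showing the rank-count arithmetic behind the `nprocs` values:

# Illustrative only: each test needs tp_size * pp_size ranks.
def required_nprocs(tp_size: int, pp_size: int) -> int:
    return tp_size * pp_size


# New TP-only test: tp_size=2, pp_size=1 -> spawn(check_tp_inference, nprocs=2).
assert required_nprocs(tp_size=2, pp_size=1) == 2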