diff --git a/colossalai/inference/hybridengine/engine.py b/colossalai/inference/hybridengine/engine.py index bb0b4c77a..8e46802dd 100644 --- a/colossalai/inference/hybridengine/engine.py +++ b/colossalai/inference/hybridengine/engine.py @@ -85,8 +85,6 @@ class CaiInferEngine: assert max_batch_size <= 64, "Max batch size exceeds the constraint" assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint" - # TODO: support only tensor parallel inference - assert pp_size > 1, "Not support only tensor parallel inference." self.pp_size = pp_size self.tp_size = tp_size @@ -102,23 +100,21 @@ class CaiInferEngine: # Init pg mesh pg_mesh = ProcessGroupMesh(pp_size, tp_size) - stage_manager = None - if pp_size > 1: - stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True) - self.cache_manager_list = [ - self._init_manager(model, max_batch_size, max_input_len, max_output_len) - for _ in range(micro_batch_buffer_size or pp_size) - ] - self.mb_manager = MicroBatchManager( - stage_manager.stage, - micro_batch_size, - micro_batch_buffer_size or pp_size, - max_input_len, - max_output_len, - self.cache_manager_list, - ) - self.verbose = verbose - self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose) + stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True) + self.cache_manager_list = [ + self._init_manager(model, max_batch_size, max_input_len, max_output_len) + for _ in range(micro_batch_buffer_size or pp_size) + ] + self.mb_manager = MicroBatchManager( + stage_manager.stage, + micro_batch_size, + micro_batch_buffer_size or pp_size, + max_input_len, + max_output_len, + self.cache_manager_list, + ) + self.verbose = verbose + self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose) self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS)) @@ -146,7 +142,7 @@ class CaiInferEngine: shardconfig = ShardConfig( tensor_parallel_process_group=tp_group, pipeline_stage_manager=stage_manager, - enable_tensor_parallelism=False, + enable_tensor_parallelism=True if self.tp_size > 1 else False, enable_fused_normalization=False, enable_all_optimization=False, enable_flash_attention=False, diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index 91bb96a1f..dda46a756 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -103,4 +103,4 @@ class MemoryManager: self.available_size = len(self.mem_state) self.mem_state[:] = 1 self.max_len_in_batch = 0 - self.logger.info("freed all space of memory manager") + # self.logger.info("freed all space of memory manager") diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index db02dab59..3ec0f97a7 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -93,9 +93,7 @@ class GenerateSchedule(PipelineSchedule): Returns: dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` """ - model_inputs = { - 'infer_state': self.mb_manager.cur_descrption.infer_state - } + model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state} return model_inputs def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): @@ -129,8 +127,8 @@ class GenerateSchedule(PipelineSchedule): def _init_infer_state_action(self) -> None: """ - This action is only for no first stage, to load batch and init infer_state. 
- 1.Load micro_batch 2.Use the current micro_batch to init the current infer_state + This action is only for no first stage, to load batch and init infer_state. + 1.Load micro_batch 2.Use the current micro_batch to init the current infer_state """ inputs_dict = self.load_micro_batch() self.mb_manager.add_descrption(inputs_dict) @@ -145,19 +143,19 @@ class GenerateSchedule(PipelineSchedule): if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + interval_inputs = {"infer_state": self.mb_manager.cur_infer_state} output_dict = model_forward(model, inputs_dict, interval_inputs) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _gen_token_action(self, model: Module): """ - This action is only for first stage + This action is only for first stage 1.do the forward with hidden_states to generate new tokens 2.step to update """ hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" - interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state} + interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state} logits = model_forward(model, None, interval_inputs) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() @@ -178,18 +176,18 @@ class GenerateSchedule(PipelineSchedule): new_token = self.action_interval_buffer.new_token assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None" inputs_dict = self._prepare_inputs_for_new_token(new_token) - interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + interval_inputs = {"infer_state": self.mb_manager.cur_infer_state} output_dict = model_forward(model, inputs_dict, interval_inputs) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _body_encoding_action(self, model: Module): hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When not first stage, the hidden states should not be None" - interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state} + interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state} output_dict = model_forward(model, None, interval_inputs) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _comm_action(self, recv_pre: bool) -> torch.Tensor: """ @@ -233,12 +231,73 @@ class GenerateSchedule(PipelineSchedule): return actions + def _gen_one_stage_action(self, model: Module): + """ + In this function, it will generate a sequence action for current state, and do the action one by one. + + Args: + model (Module): Model to be run. + + Returns: + List[Callable]: A list of action, each action is a callable function, and it will be called in order. 
+        """
+        actions = []
+
+        if self.mb_manager.cur_state is Status.PREFILL:
+            actions.append(partial(self._load_stage_action, model))
+        elif self.mb_manager.cur_state is Status.GENERATE:
+            actions.append(partial(self._gen_token_action, model))
+            actions.append(partial(self._head_encoding_action, model))
+        elif self.mb_manager.cur_state is Status.COOLDOWN:
+            actions.append(partial(self._gen_token_action, model))
+
+        return actions
+
     def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
-        if self.stage_manager.num_stages == 2:
+        if self.stage_manager.num_stages == 1:
+            return self.generate_step_one_stage(model, data_iter)
+        elif self.stage_manager.num_stages == 2:
             return self.generate_step_p2p(model, data_iter)
         else:
             return self.generate_step_broadcast(model, data_iter)
 
+    @torch.no_grad()
+    def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+        """
+        Forward one step of the pipeline when pipeline size is 1.
+
+        Args:
+            model (Module): Model to be run.
+            data_iter (Iterable): Data iterator.
+
+        Returns:
+            The new tokens generated for the whole batch, together with the recorded timestamps when `verbose` is enabled.
+        """
+        output_sequence = []
+        self.load_batch(data_iter)
+        model.eval()
+        self.comm_dtype = model.dtype
+
+        whole_timestamp = []
+
+        # run by round
+        for _ in range(self.round):
+            self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
+            self.action_interval_buffer.clear()
+            while self.mb_manager.is_micro_batch_done() is False:
+                actions = self._gen_one_stage_action(model)
+                for action in actions:
+                    action()
+                self.mb_manager.next()
+            # All micro batches in the current round are DONE
+            output_sequence.extend(self.mb_manager.export_new_tokens())
+
+            self.mb_manager.clear()
+        if self.verbose:
+            whole_timestamp.extend(self.timestamps)
+
+        return output_sequence, whole_timestamp
+
     @torch.no_grad()
     def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
         """
@@ -319,7 +378,7 @@ class GenerateSchedule(PipelineSchedule):
                 torch.cuda.synchronize()
                 self.timestamps[self.mb_manager.idx].append(time.time())
             self.mb_manager.add_descrption(inputs_dict)
-            interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
+            interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
             output_dict = model_forward(model, inputs_dict, interval_inputs)
         # In GENERATE phase
         else:
@@ -330,18 +389,23 @@ class GenerateSchedule(PipelineSchedule):
                 assert (
                     hidden_states is not None
                 ), "When first stage in GENERATE phase, the hidden states should not be None"
-                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
+                interval_inputs = {
+                    "hidden_states": hidden_states["hidden_states"],
+                    "infer_state": self.mb_manager.cur_infer_state,
+                }
                 logits = model_forward(model, None, interval_inputs)
                 if self.verbose and self.stage_manager.is_first_stage():
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
-                assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
-                new_token = self._get_token_id(logits['logits'])
+                assert (
+                    "logits" in logits
+                ), f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
+                new_token = self._get_token_id(logits["logits"])
                 self.mb_manager.step(new_token)
                 # If the current micro batch is not DONE, 
go through blocks if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): inputs_dict = self._prepare_inputs_for_new_token(new_token) - interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + interval_inputs = {"infer_state": self.mb_manager.cur_infer_state} output_dict = model_forward(model, inputs_dict, interval_inputs) else: assert hidden_states is not None, "When not first stage, the hidden states should not be None" @@ -350,7 +414,10 @@ class GenerateSchedule(PipelineSchedule): if self.mb_manager.cur_state is Status.PREFILL: inputs_dict = self.load_micro_batch() self.mb_manager.add_descrption(inputs_dict) - interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state} + interval_inputs = { + "hidden_states": hidden_states["hidden_states"], + "infer_state": self.mb_manager.cur_infer_state, + } output_dict = model_forward(model, inputs_dict, interval_inputs) # Current microbatch is not DONE, send hidden_state to next stage diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_hybrid_llama.py similarity index 83% rename from tests/test_infer/test_pipeline_infer.py rename to tests/test_infer/test_hybrid_llama.py index 3544153da..ca2349b18 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_hybrid_llama.py @@ -65,6 +65,16 @@ def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch torch.cuda.empty_cache() +@parameterize("tp_size", [2]) +@parameterize("pp_size", [1]) +@parameterize("max_output_len", [2]) +@parameterize("micro_batch_size", [1]) +@clear_cache_before_run() +def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): + pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) + torch.cuda.empty_cache() + + def check_pipeline_inference(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_pipeline_inference_test() @@ -75,6 +85,11 @@ def check_tp_pipeline_inference(rank, world_size, port): run_tp_pipeline_inference_test() +def check_tp_inference(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_tp_inference_test() + + @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @@ -82,6 +97,7 @@ def check_tp_pipeline_inference(rank, world_size, port): def test_pipeline_inference(): spawn(check_pipeline_inference, nprocs=2) spawn(check_tp_pipeline_inference, nprocs=4) + spawn(check_tp_inference, nprocs=2) if __name__ == "__main__":
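
Usage sketch (illustrative only, not part of the patch): with the `pp_size > 1` assertion removed and `enable_tensor_parallelism` tied to `tp_size`, `CaiInferEngine` can be driven in tensor-parallel-only mode (pp_size=1). The import path, the `LlamaModelInferPolicy` name, the checkpoint path, and the `generate()` call below are assumptions inferred from the test helpers above, not confirmed APIs.

    # Hypothetical TP-only run, e.g. `torchrun --nproc_per_node=2 run_tp_infer.py`.
    import colossalai
    import torch
    from transformers import LlamaForCausalLM, LlamaTokenizer

    # Assumed exports; the engine itself lives in colossalai/inference/hybridengine/engine.py.
    from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy

    colossalai.launch_from_torch(config={})

    model = LlamaForCausalLM.from_pretrained("path/to/llama", torch_dtype=torch.float16)  # placeholder path
    tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")

    engine = CaiInferEngine(
        tp_size=2,      # tensor parallelism only
        pp_size=1,      # allowed now that the pp_size > 1 assert is gone
        model=model,
        model_policy=LlamaModelInferPolicy(),
        max_output_len=32,
        micro_batch_size=1,
    )

    inputs = tokenizer("Introduce a landmark in Beijing", return_tensors="pt")
    output = engine.generate(inputs)  # assumed entry point, mirroring pipeline_inference_test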