[hotfix/hybridengine] fix bug when tp*pp size = 1 (#5069)

This commit is contained in:
Bin Jia
2023-11-20 17:15:37 +08:00
committed by GitHub
parent e5ce4c8ea6
commit 0c7d8bebd5
4 changed files with 64 additions and 14 deletions

View File

@@ -126,7 +126,7 @@ class CaiInferEngine:
# Init pg mesh
pg_mesh = ProcessGroupMesh(pp_size, tp_size)
- stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
+ stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True if pp_size * tp_size > 1 else False)
self.cache_manager_list = [
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
@@ -142,7 +142,9 @@ class CaiInferEngine:
self.verbose = verbose
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
- self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
+ self.model = self._shardformer(
+ model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS) if pp_size * tp_size > 1 else None
+ )
if quant == "gptq":
self.gptq_manager.post_init_gptq_buffer(self.model)