mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
[hotfix/hybridengine] fix bug when tp*pp size = 1 (#5069)
This commit is contained in:
@@ -126,7 +126,7 @@ class CaiInferEngine:
|
||||
# Init pg mesh
|
||||
pg_mesh = ProcessGroupMesh(pp_size, tp_size)
|
||||
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True if pp_size * tp_size > 1 else False)
|
||||
self.cache_manager_list = [
|
||||
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
|
||||
for _ in range(micro_batch_buffer_size or pp_size)
|
||||
@@ -142,7 +142,9 @@ class CaiInferEngine:
|
||||
self.verbose = verbose
|
||||
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
|
||||
|
||||
self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
|
||||
self.model = self._shardformer(
|
||||
model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS) if pp_size * tp_size > 1 else None
|
||||
)
|
||||
if quant == "gptq":
|
||||
self.gptq_manager.post_init_gptq_buffer(self.model)
|
||||
|
||||
|
Reference in New Issue
Block a user