[pipeline] 1f1b schedule receive microbatch size (#4589)

This commit is contained in:
Hongxin Liu
2023-09-01 21:45:14 +08:00
committed by GitHub
parent 38ccb8b1a3
commit 508ca36fe3
3 changed files with 30 additions and 7 deletions

View File

@@ -247,6 +247,9 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
@@ -278,6 +281,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
@@ -324,7 +328,9 @@ class HybridParallelPlugin(PipelinePluginBase):
assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)