[pipeline] 1f1b schedule receive microbatch size (#4589)
@@ -247,6 +247,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT. Defaults to False.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
         initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
         min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
         growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
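
Taken together, the new docstring entries mean a caller can size each microbatch directly instead of counting microbatches up front. A minimal usage sketch under that reading; the tp_size/pp_size values are illustrative assumptions, not part of this commit:

    # Hedged usage sketch: construct the plugin with the new microbatch_size
    # argument instead of num_microbatches (which, if given, takes precedence).
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    plugin = HybridParallelPlugin(
        tp_size=1,          # illustrative parallel sizes, not from the diff
        pp_size=2,
        microbatch_size=4,  # new in this commit; ignored if num_microbatches is set
    )
    booster = Booster(plugin=plugin)
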
@@ -278,6 +281,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         num_microbatches: Optional[int] = None,
+        microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
         min_scale: float = 1,
         growth_factor: float = 2,
@@ -324,7 +328,9 @@ class HybridParallelPlugin(PipelinePluginBase):
             assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
             assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
             self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
-            self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
+            self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
+                                                          num_microbatches=num_microbatches,
+                                                          microbatch_size=microbatch_size)
             self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
             self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
             self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
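
The rewritten call passes the stage manager first and forwards both knobs as keywords, letting the 1F1B schedule decide the microbatch count itself. A sketch of the derivation it presumably performs per batch; the helper name resolve_num_microbatches and the assertion messages are hypothetical, not code from this diff:

    from typing import Optional

    def resolve_num_microbatches(batch_size: int,
                                 num_microbatches: Optional[int] = None,
                                 microbatch_size: Optional[int] = None) -> int:
        # Illustrative helper mirroring the documented precedence: if
        # num_microbatches is given, microbatch_size is ignored; otherwise
        # the count is derived from the incoming batch. A sketch, not the
        # plugin's actual code.
        if num_microbatches is not None:
            assert batch_size % num_microbatches == 0, \
                'batch size must be divisible by the number of microbatches'
            return num_microbatches
        assert microbatch_size is not None, \
            'either num_microbatches or microbatch_size must be provided'
        assert batch_size % microbatch_size == 0, \
            'batch size must be divisible by the microbatch size'
        return batch_size // microbatch_size

    # e.g. a global batch of 32 with microbatch_size=4 yields 8 microbatches
    assert resolve_num_microbatches(32, microbatch_size=4) == 8
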