[shardformer] test all optimizations (#4399)

[shardformer] test all optimizations

[shardformer] test all optimizations

[shardformer] test all optimizations
This commit is contained in:
flybird1111
2023-08-10 13:59:30 +08:00
committed by Hongxin Liu
parent 7a3dfd0c64
commit d2cd48e0be
4 changed files with 59 additions and 29 deletions

View File

@@ -148,7 +148,10 @@ class HybridParallelPlugin(PipelinePluginBase):
precision: str = 'fp16',
zero_stage: int = 0,
cpu_offload: bool = False,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
num_microbatches: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
@@ -171,7 +174,10 @@ class HybridParallelPlugin(PipelinePluginBase):
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
@@ -186,7 +192,10 @@ class HybridParallelPlugin(PipelinePluginBase):
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_fused_normalization=self.enable_fused_normalization)
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,