[shardformer] test all optimizations (#4399)
@@ -148,7 +148,10 @@ class HybridParallelPlugin(PipelinePluginBase):
                  precision: str = 'fp16',
                  zero_stage: int = 0,
                  cpu_offload: bool = False,
+                 enable_all_optimization: bool = False,
                  enable_fused_normalization: bool = False,
+                 enable_flash_attention: bool = False,
+                 enable_jit_fused: bool = False,
                  num_microbatches: Optional[int] = None,
                  initial_scale: float = 2**16,
                  min_scale: float = 1,
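The three new constructor flags expose shardformer's per-feature optimizations directly on the plugin, with enable_all_optimization acting as an umbrella switch. A minimal usage sketch, assuming a standard torchrun launch; the tp_size/pp_size values are illustrative, not taken from this commit:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

# Turn on every shardformer optimization at once via the umbrella flag,
# or toggle flash attention / JIT fusion / fused normalization individually.
plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=1,
                              precision='fp16',
                              enable_all_optimization=True)
booster = Booster(plugin=plugin)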
@@ -171,7 +174,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
+        self.enable_all_optimization = enable_all_optimization
         self.enable_fused_normalization = enable_fused_normalization
+        self.enable_flash_attention = enable_flash_attention
+        self.enable_jit_fused = enable_jit_fused
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
@@ -186,7 +192,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                         pipeline_stage_manager=self.stage_manager,
                                         enable_tensor_parallelism=self.tp_size > 1,
-                                        enable_fused_normalization=self.enable_fused_normalization)
+                                        enable_all_optimization=self.enable_all_optimization,
+                                        enable_fused_normalization=self.enable_fused_normalization,
+                                        enable_flash_attention=self.enable_flash_attention,
+                                        enable_jit_fused=self.enable_jit_fused)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
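This hunk forwards enable_all_optimization to ShardConfig alongside the individual switches. A hedged sketch of how such an umbrella flag can be folded into the per-feature switches; the field names follow the diff, but the dataclass name ShardConfigSketch and the __post_init__ hook are assumptions about the implementation, not code from this commit:

from dataclasses import dataclass

@dataclass
class ShardConfigSketch:
    enable_all_optimization: bool = False
    enable_fused_normalization: bool = False
    enable_flash_attention: bool = False
    enable_jit_fused: bool = False

    def __post_init__(self):
        # If the umbrella flag is set, switch every optimization on,
        # so callers only need to pass one argument for the common case.
        if self.enable_all_optimization:
            self.enable_fused_normalization = True
            self.enable_flash_attention = True
            self.enable_jit_fused = True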