Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 13:00:52 +00:00
[shardformer] Add optional overlap for HybridParallelPlugin (#4615)
* add optional overlap for plugin
* remove fixed todo
@@ -280,6 +280,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
+        enable_sequence_overlap: bool = False,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -341,7 +342,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_fused_normalization=self.enable_fused_normalization,
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
-            enable_sequence_parallelism=enable_sequence_parallelism)
+            enable_sequence_parallelism=enable_sequence_parallelism,
+            enable_sequence_overlap=enable_sequence_overlap)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
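In effect, the commit adds an `enable_sequence_overlap` constructor argument and forwards it into the shard config alongside `enable_sequence_parallelism`. A minimal usage sketch, assuming the standard Booster workflow; the `tp_size`/`pp_size` values here are illustrative placeholders, and only the two `enable_sequence_*` flags come from this change:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Sequence overlap is only meaningful when sequence parallelism is enabled,
# since it overlaps sequence-parallel communication with computation.
plugin = HybridParallelPlugin(
    tp_size=2,                          # illustrative tensor-parallel degree
    pp_size=1,                          # illustrative pipeline-parallel degree
    enable_sequence_parallelism=True,
    enable_sequence_overlap=True,       # new flag added by this commit
)
booster = Booster(plugin=plugin)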