[shardformer] Add overlap optional for HybridParallelPlugin (#4615)

* add optional overlap for plugin

* remove fixed todo
This commit is contained in:
Bin Jia
2023-09-05 11:52:23 +08:00
committed by GitHub
parent a39a5c66fe
commit 86d22581e4
2 changed files with 3 additions and 3 deletions

View File

@@ -280,6 +280,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
@@ -341,7 +342,8 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism)
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,