[shardformer] optimize seq parallelism (#6086)

* [shardformer] optimize seq parallelism

* [shardformer] fix gpt2 fused linear col

* [plugin] update gemini plugin

* [plugin] update moe hybrid plugin

* [test] update gpt2 fused linear test

* [shardformer] fix gpt2 fused linear reduce
This commit is contained in:
Hongxin Liu
2024-10-11 13:44:40 +08:00
committed by GitHub
parent 6b2c506fc5
commit dc2cdaf3e8
13 changed files with 111 additions and 278 deletions

View File

@@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
@@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention: bool = False,
enable_sequence_parallelism: bool = False,
enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
verbose: bool = False,
@@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase):
self.enable_flash_attention = enable_flash_attention
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose
self.tp_size = tp_size
@@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=self.enable_sequence_parallelism,
enable_sequence_overlap=self.enable_sequence_overlap,
)
def __del__(self):