Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[shardformer] optimize seq parallelism (#6086)
* [shardformer] optimize seq parallelism
* [shardformer] fix gpt2 fused linear col
* [plugin] update gemini plugin
* [plugin] update moe hybrid plugin
* [test] update gpt2 fused linear test
* [shardformer] fix gpt2 fused linear reduce
@@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase):
         enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
         verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
@@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase):
         enable_flash_attention: bool = False,
         enable_sequence_parallelism: bool = False,
         enable_jit_fused: bool = False,
-        enable_sequence_overlap: bool = False,
         enable_async_reduce: bool = True,
         use_fp8: bool = False,
         verbose: bool = False,
@@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase):
         self.enable_flash_attention = enable_flash_attention
         self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
         self.enable_jit_fused = enable_jit_fused
-        self.enable_sequence_overlap = enable_sequence_overlap
         self.verbose = verbose
 
         self.tp_size = tp_size
@@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase):
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=self.enable_sequence_parallelism,
-            enable_sequence_overlap=self.enable_sequence_overlap,
         )
 
     def __del__(self):