mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-18 16:00:49 +00:00
[shardformer] optimize seq parallelism (#6086)
* [shardformer] optimize seq parallelism * [shardformer] fix gpt2 fused linear col * [plugin] update gemini plugin * [plugin] update moe hybrid plugin * [test] update gpt2 fused linear test * [shardformer] fix gpt2 fused linear reduce
This commit is contained in:
@@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase):
|
||||
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
|
||||
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
|
||||
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
|
||||
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
|
||||
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
|
||||
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
|
||||
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
|
||||
@@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase):
|
||||
enable_flash_attention: bool = False,
|
||||
enable_sequence_parallelism: bool = False,
|
||||
enable_jit_fused: bool = False,
|
||||
enable_sequence_overlap: bool = False,
|
||||
enable_async_reduce: bool = True,
|
||||
use_fp8: bool = False,
|
||||
verbose: bool = False,
|
||||
@@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase):
|
||||
self.enable_flash_attention = enable_flash_attention
|
||||
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
|
||||
self.enable_jit_fused = enable_jit_fused
|
||||
self.enable_sequence_overlap = enable_sequence_overlap
|
||||
self.verbose = verbose
|
||||
|
||||
self.tp_size = tp_size
|
||||
@@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase):
|
||||
enable_flash_attention=self.enable_flash_attention,
|
||||
enable_jit_fused=self.enable_jit_fused,
|
||||
enable_sequence_parallelism=self.enable_sequence_parallelism,
|
||||
enable_sequence_overlap=self.enable_sequence_overlap,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
|
@@ -951,7 +951,6 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
|
||||
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
|
||||
sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
|
||||
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
|
||||
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
|
||||
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
|
||||
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
|
||||
@@ -1002,7 +1001,6 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
enable_jit_fused: bool = False,
|
||||
enable_sequence_parallelism: bool = False,
|
||||
sequence_parallelism_mode: str = None,
|
||||
enable_sequence_overlap: bool = False,
|
||||
parallel_output: bool = True,
|
||||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
@@ -1174,7 +1172,6 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
enable_jit_fused=self.enable_jit_fused,
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
sequence_parallelism_mode=sequence_parallelism_mode,
|
||||
enable_sequence_overlap=enable_sequence_overlap,
|
||||
parallel_output=parallel_output,
|
||||
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
|
||||
gradient_checkpoint_config=gradient_checkpoint_config,
|
||||
|
@@ -140,7 +140,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
|
||||
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
|
||||
sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
|
||||
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
|
||||
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
|
||||
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
|
||||
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
|
||||
@@ -189,7 +188,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||
enable_jit_fused: bool = False,
|
||||
enable_sequence_parallelism: bool = False,
|
||||
sequence_parallelism_mode: str = None,
|
||||
enable_sequence_overlap: bool = False,
|
||||
parallel_output: bool = True,
|
||||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
@@ -351,7 +349,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||
enable_jit_fused=self.enable_jit_fused,
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
sequence_parallelism_mode=sequence_parallelism_mode,
|
||||
enable_sequence_overlap=enable_sequence_overlap,
|
||||
parallel_output=parallel_output,
|
||||
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
|
||||
gradient_checkpoint_config=gradient_checkpoint_config,
|
||||
|
Reference in New Issue
Block a user