[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)

* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)

* [sequence parallel] add sequence parallel linear col/row support (#4336)

* add sequence parallel linear col/row support

* add annotation

* add annotation

* add support for gpt2 fused qkv linear layer

* support sequence parallel in GPT2

* add docstring and note

* add requirements

* remove unused flash-attn

* modify flash attn test

* modify flash attn setting

* modify flash attn code

* add assert before divide, rename forward function

* [shardformer/test] fix gpt2 test with seq-parallel

* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)

* overlap gather input / grad computing during col backward

* modify test for overlap

* simplify code

* fix code and modify cuda stream synchronize

* [shardformer/sequence parallel] polish code
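The gist of the col/row support (#4336), sketched below in plain PyTorch rather than the actual ColossalAI layers: with sequence parallelism the activation is sharded along the sequence dimension, so a column-parallel linear must all-gather its input along that dimension before the matmul, while a row-parallel linear reduce-scatters its partial output back into sequence shards. Function names such as seq_parallel_col_forward and the [batch, seq, hidden] layout are illustrative assumptions, not the library's API.

# Illustrative sketch only, not the ColossalAI implementation. Assumes
# torch.distributed is initialized and `group` is the tensor-parallel group.
import torch
import torch.distributed as dist


def gather_along_seq_dim(x: torch.Tensor, group, seq_dim: int = 1) -> torch.Tensor:
    """All-gather a sequence-sharded activation [B, S/P, H] into [B, S, H]."""
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(shards, x.contiguous(), group=group)
    return torch.cat(shards, dim=seq_dim)


def reduce_scatter_along_seq_dim(x: torch.Tensor, group, seq_dim: int = 1) -> torch.Tensor:
    """Reduce-scatter a full activation [B, S, H] back into a [B, S/P, H] shard."""
    world_size = dist.get_world_size(group)
    assert x.size(seq_dim) % world_size == 0, "sequence length must be divisible by group size"
    shards = [c.contiguous() for c in torch.chunk(x, world_size, dim=seq_dim)]
    out = torch.empty_like(shards[0])
    dist.reduce_scatter(out, shards, group=group)
    return out


def seq_parallel_col_forward(x_shard, weight, bias, group):
    # Column-parallel linear: the weight is split along the output dimension,
    # so every rank needs the full sequence -- gather the input first.
    full_x = gather_along_seq_dim(x_shard, group)
    return torch.nn.functional.linear(full_x, weight, bias)


def seq_parallel_row_forward(x, weight, group):
    # Row-parallel linear: each rank holds a partial sum over the hidden
    # dimension; reduce-scatter sums the partials and restores sequence shards.
    partial = torch.nn.functional.linear(x, weight)
    return reduce_scatter_along_seq_dim(partial, group)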
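The overlap commit (#4401) targets the column-parallel backward pass: the sequence-sharded input has to be re-gathered to form the weight gradient, and that all-gather can be issued asynchronously and hidden behind the computation of the input gradient, which needs only the output gradient and this rank's weight shard. A minimal sketch of the idea, again with illustrative names and reusing the helpers from the sketch above, not the library's actual autograd function:

# Sketch of the communication/computation overlap in the column-parallel
# backward (not the actual ColossalAI autograd function).
import torch
import torch.distributed as dist


def seq_parallel_col_backward(grad_output, x_shard, weight, group, seq_dim=1):
    world_size = dist.get_world_size(group)

    # 1) Launch the all-gather of the sequence-sharded input asynchronously;
    #    the full input is only needed later, for the weight gradient.
    shards = [torch.empty_like(x_shard) for _ in range(world_size)]
    handle = dist.all_gather(shards, x_shard.contiguous(), group=group, async_op=True)

    # 2) Overlap: compute the input gradient while the gather is in flight
    #    (it needs only grad_output and this rank's weight shard).
    grad_input_full = grad_output.matmul(weight)  # [B, S, H_in]

    # 3) Wait for the gathered input, then form the weight and bias gradients.
    handle.wait()
    full_x = torch.cat(shards, dim=seq_dim)
    grad_weight = grad_output.flatten(0, 1).t().matmul(full_x.flatten(0, 1))
    grad_bias = grad_output.sum(dim=(0, 1))

    # 4) Each rank produced a partial input gradient over the full sequence;
    #    reduce-scatter sums the partials and restores the sequence sharding.
    grad_input = reduce_scatter_along_seq_dim(grad_input_full, group, seq_dim)
    return grad_input, grad_weight, grad_bias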
Author: Bin Jia
Date: 2023-08-16 15:41:20 +08:00 (committed by GitHub)
Parent: d20dceb9a3
Commit: 424629fea0
12 changed files with 655 additions and 65 deletions


@@ -152,6 +152,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_fused_normalization: bool = False,
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
+        enable_sequence_parallelism: bool = False,
         num_microbatches: Optional[int] = None,
         initial_scale: float = 2**16,
         min_scale: float = 1,
@@ -178,6 +179,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_fused_normalization = enable_fused_normalization
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
+        self.enable_sequence_parallelism = enable_sequence_parallelism
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
@@ -195,7 +197,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_all_optimization=self.enable_all_optimization,
             enable_fused_normalization=self.enable_fused_normalization,
             enable_flash_attention=self.enable_flash_attention,
-            enable_jit_fused=self.enable_jit_fused)
+            enable_jit_fused=self.enable_jit_fused,
+            enable_sequence_parallelism=enable_sequence_parallelism)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
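
Based on the constructor signature above, enabling the feature from user code would look roughly like the following; the launch/booster scaffolding and the other arguments are illustrative and not taken from this commit:

# Sketch of enabling the new flag, based on the signature in the diff above.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(
    tp_size=2,                         # tensor-parallel degree (illustrative)
    pp_size=1,
    enable_sequence_parallelism=True,  # the flag added in this commit
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(
#     model, optimizer, criterion, dataloader)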