[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)
* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)
* [sequence parallel] add sequence parallel linear col/row support (#4336)
* add sequence parallel linear col/row support
* add annotation
* add annotation
* add support for gpt2 fused qkv linear layer
* support sequence parallel in GPT2
* add docstring and note
* add requirements
* remove unused flash-attn
* modify flash attn test
* modify flash attn setting
* modify flash attn code
* add assert before divide, rename forward function
* [shardformer/test] fix gpt2 test with seq-parallel
* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)
* overlap gather input / grad computing during col backward
* modify test for overlap
* simplify code
* fix code and modify cuda stream synchronize
* [shardformer/sequence parallel] polish code
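The core of #4336 is a column-parallel linear layer whose input is sharded along the sequence dimension within the tensor-parallel group: the shard is all-gathered before the matmul in forward, and the input gradient is reduce-scattered in backward. Below is a minimal sketch of that gather-forward/reduce-scatter-backward pattern — an illustration, not ColossalAI's actual Linear1D_Col code; the autograd function name and tensor shapes are assumptions.

    # Sketch only: assumes input laid out as (seq_len / tp_size, batch, hidden).
    import torch
    import torch.distributed as dist


    class _GatherForwardReduceScatterBackward(torch.autograd.Function):
        """All-gather along the sequence dim in forward; reduce-scatter the grad."""

        @staticmethod
        def forward(ctx, x, group):
            ctx.group = group
            world_size = dist.get_world_size(group)
            # Gather sequence shards from all TP ranks: (s/p, b, h) -> (s, b, h).
            shards = [torch.empty_like(x) for _ in range(world_size)]
            dist.all_gather(shards, x.contiguous(), group=group)
            return torch.cat(shards, dim=0)

        @staticmethod
        def backward(ctx, grad_output):
            world_size = dist.get_world_size(ctx.group)
            # Reduce-scatter the full-sequence gradient back to per-rank shards.
            grad_chunks = [c.contiguous() for c in grad_output.chunk(world_size, dim=0)]
            grad_input = torch.empty_like(grad_chunks[0])
            dist.reduce_scatter(grad_input, grad_chunks, group=ctx.group)
            return grad_input, None


    def seq_parallel_col_linear(x_shard, weight, bias, group):
        # x_shard: (s / tp_size, b, hidden); weight: (out_per_rank, hidden).
        x = _GatherForwardReduceScatterBackward.apply(x_shard, group)
        return torch.nn.functional.linear(x, weight, bias)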
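#4401 then overlaps communication with computation in that backward pass: the all-gather of the saved input shard (needed only for the weight gradient) is issued asynchronously and runs while the input gradient is computed. The real implementation synchronizes an extra CUDA stream; the sketch below conveys the same overlap via an async collective. Function names are hypothetical, and the final reduce-scatter of the input gradient is omitted.

    # Sketch of overlapping the input all-gather with grad computation.
    import torch
    import torch.distributed as dist


    def col_linear_backward_overlapped(grad_output, x_shard, weight, group):
        world_size = dist.get_world_size(group)
        # 1. Launch the all-gather of the input shards without blocking.
        shards = [torch.empty_like(x_shard) for _ in range(world_size)]
        handle = dist.all_gather(shards, x_shard.contiguous(), group=group,
                                 async_op=True)

        # 2. While the gather is in flight, compute the input gradient,
        #    which needs only grad_output and the weight.
        grad_input = grad_output @ weight        # (s, b, out) x (out, in) -> (s, b, in)

        # 3. Wait for the gathered input, then compute the weight gradient.
        handle.wait()
        x_full = torch.cat(shards, dim=0)        # (s, b, in)
        grad_weight = grad_output.flatten(0, 1).t() @ x_full.flatten(0, 1)
        # (A reduce-scatter of grad_input across the TP group would follow here.)
        return grad_input, grad_weight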
@@ -152,6 +152,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_fused_normalization: bool = False,
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
+        enable_sequence_parallelism: bool = False,
         num_microbatches: Optional[int] = None,
         initial_scale: float = 2**16,
         min_scale: float = 1,
@@ -178,6 +179,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_fused_normalization = enable_fused_normalization
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
+        self.enable_sequence_parallelism = enable_sequence_parallelism
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
@@ -195,7 +197,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_all_optimization=self.enable_all_optimization,
             enable_fused_normalization=self.enable_fused_normalization,
             enable_flash_attention=self.enable_flash_attention,
-            enable_jit_fused=self.enable_jit_fused)
+            enable_jit_fused=self.enable_jit_fused,
+            enable_sequence_parallelism=enable_sequence_parallelism)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
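On the plugin side, the diff simply threads the new enable_sequence_parallelism flag from HybridParallelPlugin.__init__ into the shard config it constructs. Usage would look roughly like the following; the constructor arguments other than the new flag are assumptions based on the attributes visible in the diff.

    # Hedged usage sketch: tp_size/pp_size values are illustrative; only the
    # enable_sequence_parallelism flag comes from this commit.
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    plugin = HybridParallelPlugin(
        tp_size=2,   # sequence parallelism shards activations across the TP group
        pp_size=1,
        enable_sequence_parallelism=True,  # the flag added by this commit
    )
    booster = Booster(plugin=plugin)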