[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)
* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)

  * [sequence parallel] add sequence parallel linear col/row support (#4336)

    * add sequence parallel linear col/row support
    * add annotation
    * add support for gpt2 fused qkv linear layer
    * support sequence parallel in GPT2
    * add docstring and note
    * add requirements
    * remove unused flash-attn
    * modify flash attn test
    * modify flash attn setting
    * modify flash attn code
    * add assert before divide, rename forward function

  * [shardformer/test] fix gpt2 test with seq-parallel

* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)

  * overlap input gather / grad computation during col backward
  * modify test for overlap
  * simplify code
  * fix code and modify cuda stream synchronization

* [shardformer/sequence parallel] polish code
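The core mechanism behind #4336 and #4401, sketched below: a column-parallel linear receives an input sharded along the sequence dimension, all-gathers it before the matmul, and during backward overlaps the re-gather of the input (needed for the weight gradient) with the input-gradient matmul. This is a minimal PyTorch sketch of the idea, not ColossalAI's actual implementation; the class name _SeqParallelColumnLinear and the layout assumptions (sequence-first tensors, default process group, seq_len divisible by world size) are hypothetical.

    import torch
    import torch.distributed as dist


    class _SeqParallelColumnLinear(torch.autograd.Function):
        """Sketch: all-gather a sequence-sharded input, multiply by a
        column-sharded weight, overlap communication in backward."""

        @staticmethod
        def forward(ctx, x, weight):
            # x: (seq_len / world_size, batch, hidden); weight: (out_per_rank, hidden)
            # Save only the local shard; the full input is re-gathered in backward,
            # which is what saves activation memory.
            ctx.save_for_backward(x, weight)
            ws = dist.get_world_size()
            parts = [torch.empty_like(x) for _ in range(ws)]
            dist.all_gather(parts, x)
            full_x = torch.cat(parts, dim=0)           # (seq_len, batch, hidden)
            return full_x @ weight.t()                 # (seq_len, batch, out_per_rank)

        @staticmethod
        def backward(ctx, grad_out):
            x, weight = ctx.saved_tensors
            ws = dist.get_world_size()
            # Launch the input all-gather asynchronously ...
            parts = [torch.empty_like(x) for _ in range(ws)]
            handle = dist.all_gather(parts, x, async_op=True)
            # ... and overlap it with the input-gradient matmul.
            grad_full_x = grad_out @ weight            # (seq_len, batch, hidden)
            # The backward of an all-gather is a reduce-scatter over the seq dim.
            grad_x = torch.empty_like(x)
            dist.reduce_scatter(grad_x, list(grad_full_x.chunk(ws, dim=0)))
            handle.wait()                              # gathered input is now ready
            full_x = torch.cat(parts, dim=0)
            grad_w = grad_out.flatten(0, 1).t() @ full_x.flatten(0, 1)
            return grad_x, grad_w

The PR itself mentions a dedicated CUDA stream and explicit synchronization for the overlap ("fix code and modify cuda stream synchronize"); the async handle above is the simplest stand-in for that idea. The diff below wires the feature into the GPT2 policy.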
@@ -7,6 +7,7 @@ import colossalai.shardformer.layer as col_nn
 from .._utils import getattr_, setattr_
 from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
+from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -49,6 +50,9 @@ class GPT2Policy(Policy):
                     target_module=col_nn.DropoutForParallelInput,
                 ),
             ])
+            if self.shard_config.enable_sequence_parallelism:
+                policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
+
             policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
                 "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
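The method_replacement hook in this hunk swaps GPT2Model.forward for a function returned by gpt2_sequence_parallel_forward_fn(self.shard_config), while the attribute_replacement entries divide the attention bookkeeping by the tensor-parallel degree: with hidden_size = 768 and tensor_parallel_size = 4, for example, each rank's attn.embed_dim becomes 192. A minimal sketch of the forward-factory pattern follows; the function name and body are hypothetical stand-ins, since the real forward lives in ..modeling.gpt2_seq:

    def make_sequence_parallel_forward(shard_config):
        """Sketch of the factory pattern: close over the shard config and
        return an unbound forward to install on the module class."""

        def forward(self, hidden_states, *args, **kwargs):
            # Hypothetical body: the real replacement runs GPT2Model with
            # activations split 1/tp_size along the sequence dimension.
            tp_size = shard_config.tensor_parallel_size
            return hidden_states.chunk(tp_size, dim=1)[0]

        return forward

    # Installed by the policy roughly as:
    #     GPT2Model.forward = make_sequence_parallel_forward(shard_config)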
@@ -120,6 +124,11 @@ class GPT2Policy(Policy):
             policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
                 'forward': get_gpt2_flash_attention_forward(),
             })
+
+        if self.shard_config.enable_sequence_parallelism:
+            suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
+            self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])
+
         return policy
 
     def postprocess(self):
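append_seq_parallel_to_policy is defined on the base Policy class; the diff only shows the call site. A plausible sketch of what such a helper does, assuming the body below (it is not the actual implementation): flag each listed submodule's replacement description so the substituted layer is built in sequence-parallel mode.

    def append_seq_parallel_to_policy(self, suffix_list, module_policy_description):
        # Hypothetical body: mark matching sub-module replacements so that,
        # e.g., the Linear1D_Col / Linear1D_Row substitutes are constructed
        # with sequence parallelism enabled.
        for description in module_policy_description.sub_module_replacement:
            if description.suffix in suffix_list:
                kwargs = dict(description.kwargs or {})
                kwargs["seq_parallel"] = True
                description.kwargs = kwargs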