From 6af6d6fc9fe72997af44cdf3cb7b930a365ab915 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Mon, 29 Apr 2024 15:33:51 +0800
Subject: [PATCH] [shardformer] support bias_gelu_jit_fused for models (#5647)

* support gelu_bias_fused for gpt2

* support gelu_bias_fused for gpt2

fix

fix

fix

* fix

fix

* fix
---
 colossalai/shardformer/modeling/bert.py  | 13 +++++++++++++
 colossalai/shardformer/modeling/blip2.py | 14 ++++++++++++++
 colossalai/shardformer/modeling/gpt2.py  | 15 +++++++++++++++
 colossalai/shardformer/modeling/vit.py   | 12 ++++++++++++
 colossalai/shardformer/policies/bert.py  | 12 ++++++++++++
 colossalai/shardformer/policies/blip2.py | 14 ++++++++++++++
 colossalai/shardformer/policies/gpt2.py  | 15 ++++++++++++++-
 colossalai/shardformer/policies/vit.py   | 22 +++++++++++++++++++++-
 8 files changed, 115 insertions(+), 2 deletions(-)

diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 0838fcee6..e7679f0ec 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1287,3 +1287,16 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
         )
 
     return forward
+
+
+def get_jit_fused_bert_intermediate_forward():
+    from transformers.models.bert.modeling_bert import BertIntermediate
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.dense(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        return hidden_states
+
+    return forward
diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py
index bd84c87c6..96e8a9d0c 100644
--- a/colossalai/shardformer/modeling/blip2.py
+++ b/colossalai/shardformer/modeling/blip2.py
@@ -129,3 +129,17 @@ def get_jit_fused_blip2_QFormer_output_forward():
         return hidden_states
 
     return forward
+
+
+def get_jit_fused_blip2_mlp_forward():
+    from transformers.models.blip_2.modeling_blip_2 import Blip2MLP
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.fc1(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+    return forward
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 17acdf7fc..bfa995645 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -1310,3 +1310,18 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         )
 
     return forward
+
+
+def get_jit_fused_gpt2_mlp_forward():
+    from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states, bias = self.c_fc(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+    return forward
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index 67b10988d..b1a5c4143 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -372,3 +372,15 @@ def get_jit_fused_vit_output_forward():
         return hidden_states
 
     return forward
+
+
+def get_jit_fused_vit_intermediate_forward():
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.dense(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+
+        return hidden_states
+
+    return forward
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index d43fc893a..ad40e0e56 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -12,6 +12,7 @@ from ..modeling.bert import (
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
     get_bert_flash_attention_forward,
+    get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
 )
@@ -38,11 +39,13 @@ class BertPolicy(Policy):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
         return self.model
 
     def module_policy(self):
         from transformers.models.bert.modeling_bert import (
             BertEmbeddings,
+            BertIntermediate,
             BertLayer,
             BertModel,
             BertOutput,
@@ -131,6 +134,7 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
+                            "skip_bias_add": self.enable_bias_gelu_fused,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -153,6 +157,14 @@ class BertPolicy(Policy):
                     ),
                 ]
             )
+            if self.enable_bias_gelu_fused:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_jit_fused_bert_intermediate_forward(),
+                    },
+                    policy=policy,
+                    target_key=BertIntermediate,
+                )
 
         if sp_mode == "split_gather":
             self.append_or_create_method_replacement(
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index b845e9336..9d1f6a306 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -3,6 +3,7 @@ import colossalai.shardformer.layer as col_nn
 from ..modeling.blip2 import (
     forward_fn,
     get_blip2_flash_attention_forward,
+    get_jit_fused_blip2_mlp_forward,
     get_jit_fused_blip2_QFormer_output_forward,
     get_jit_fused_blip2_QFormer_self_output_forward,
 )
@@ -18,12 +19,16 @@ class BlipPolicy(Policy):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.enable_bias_gelu_fused = (
+            self.shard_config.enable_jit_fused and self.model.config.vision_config.hidden_act == "gelu"
+        )
         return self.model
 
     def module_policy(self):
         from transformers.models.blip_2.modeling_blip_2 import (
             Blip2Attention,
             Blip2EncoderLayer,
+            Blip2MLP,
             Blip2QFormerLayer,
             Blip2QFormerModel,
             Blip2QFormerOutput,
@@ -73,6 +78,7 @@ class BlipPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="mlp.fc1",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.fc2",
@@ -201,6 +207,14 @@ class BlipPolicy(Policy):
             )
 
         policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_blip2_mlp_forward(),
+                },
+                policy=policy,
+                target_key=Blip2MLP,
+            )
 
         if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 6f4f835a8..531c2153b 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -10,6 +10,7 @@ from ..modeling.gpt2 import (
     GPT2PipelineForwards,
     get_gpt2_flash_attention_forward,
     get_gpt_model_forward_for_flash_attn,
+    get_jit_fused_gpt2_mlp_forward,
     get_lm_forward_with_dist_cross_entropy,
     gpt2_sequence_parallel_forward_fn,
 )
@@ -36,10 +37,13 @@ class GPT2Policy(Policy):
         """
         self.tie_weight = self.tie_weight_check()
         self.origin_attn_implement = self.model.config._attn_implementation
+        self.enable_bias_gelu_fused = (
+            self.shard_config.enable_jit_fused and self.model.config.activation_function == "gelu"
+        )
         return self.model
 
     def module_policy(self):
-        from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+        from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 
         ATTN_IMPLEMENTATION = {
             "eager": GPT2Attention,
@@ -119,6 +123,7 @@ class GPT2Policy(Policy):
                             "n_fused": 1,
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
+                            "skip_bias_add": self.enable_bias_gelu_fused,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -142,6 +147,14 @@ class GPT2Policy(Policy):
                     ),
                 ],
             )
+            if self.enable_bias_gelu_fused:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_jit_fused_gpt2_mlp_forward(),
+                    },
+                    policy=policy,
+                    target_key=GPT2MLP,
+                )
         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
             self.append_or_create_submodule_replacement(
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 905398c4d..b7883af9f 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -11,6 +11,7 @@ from ..modeling.vit import (
     ViTForImageClassification_pipeline_forward,
     ViTForMaskedImageModeling_pipeline_forward,
     ViTModel_pipeline_forward,
+    get_jit_fused_vit_intermediate_forward,
     get_jit_fused_vit_output_forward,
     get_vit_flash_self_attention_forward,
 )
@@ -24,10 +25,17 @@ class ViTPolicy(Policy):
         pass
 
     def preprocess(self):
+        self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
        return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention
+        from transformers.models.vit.modeling_vit import (
+            ViTEmbeddings,
+            ViTIntermediate,
+            ViTLayer,
+            ViTOutput,
+            ViTSelfAttention,
+        )
 
         policy = {}
 
@@ -83,6 +91,9 @@ class ViTPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="intermediate.dense",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "skip_bias_add": self.enable_bias_gelu_fused,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dense",
@@ -94,6 +105,14 @@ class ViTPolicy(Policy):
                     ),
                 ],
             )
+            if self.enable_bias_gelu_fused:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_jit_fused_vit_intermediate_forward(),
+                    },
+                    policy=policy,
+                    target_key=ViTIntermediate,
+                )
 
         # use flash attention
         if self.shard_config.enable_flash_attention:
@@ -115,6 +134,7 @@ class ViTPolicy(Policy):
                 policy=policy,
                 target_key=ViTOutput,
             )
+
         return policy
 
     def new_model_class(self):
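
Note (not part of the patch): every file above follows the same pattern. The policy creates the intermediate linear with `skip_bias_add=True`, so the layer returns `(output, bias)` without adding the bias itself, and the replaced forward then calls `JitGeLUFunction.apply(output, bias)` so the bias addition and the GeLU are fused into one JIT kernel; the fusion is only enabled when `enable_jit_fused` is set and the model's activation is plain `"gelu"`. Below is a minimal, self-contained sketch of that pattern for illustration only. `fused_bias_gelu`, `SkipBiasLinear`, and `FusedMLP` are assumed names invented here; they are not the actual `colossalai.kernel.jit.bias_gelu.GeLUFunction` (which is applied as a torch.autograd.Function with its own backward) or `Linear1D_Col`.

# Illustrative sketch only: assumed helper names, not the ColossalAI kernels.
import torch
import torch.nn as nn


@torch.jit.script
def fused_bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Bias add and tanh-approximated GeLU evaluated in one scripted op.
    y = x + bias
    return 0.5 * y * (1.0 + torch.tanh(0.7978845608028654 * (y + 0.044715 * y * y * y)))


class SkipBiasLinear(nn.Linear):
    # Returns (output, bias) instead of adding the bias, mimicking skip_bias_add=True.
    def forward(self, x: torch.Tensor):
        return nn.functional.linear(x, self.weight), self.bias


class FusedMLP(nn.Module):
    # Mirrors the GPT2MLP-style forward wired up by the patch: fc -> fused bias+GeLU -> proj.
    def __init__(self, hidden: int, inner: int):
        super().__init__()
        self.c_fc = SkipBiasLinear(hidden, inner)
        self.c_proj = nn.Linear(inner, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, bias = self.c_fc(x)
        x = fused_bias_gelu(x, bias)
        return self.c_proj(x)


if __name__ == "__main__":
    out = FusedMLP(64, 256)(torch.randn(2, 8, 64))
    print(out.shape)  # torch.Size([2, 8, 64])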