[shardformer] support bias_gelu_jit_fused for models (#5647)

* support gelu_bias_fused for gpt2

* fix
Author: flybird11111
Date: 2024-04-29 15:33:51 +08:00
Committed by: GitHub
Parent: 7f8b16635b
Commit: 6af6d6fc9f
8 changed files with 115 additions and 2 deletions
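
For context: "bias_gelu_jit_fused" defers the intermediate linear layer's bias add into a TorchScript-compiled GELU, so the bias add and the activation run as a single fused pointwise kernel instead of two. A minimal sketch of the pattern (Megatron-style tanh approximation of GELU; the name bias_gelu is illustrative, not necessarily ColossalAI's actual kernel):

    import torch

    @torch.jit.script
    def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Apply the deferred bias, then the tanh approximation of GELU;
        # scripting lets TorchScript fuse these pointwise ops into one kernel.
        x = bias + y
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))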


@@ -12,6 +12,7 @@ from ..modeling.bert import (
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
     get_bert_flash_attention_forward,
+    get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
 )
@@ -38,11 +39,13 @@ class BertPolicy(Policy):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
         return self.model
 
     def module_policy(self):
         from transformers.models.bert.modeling_bert import (
             BertEmbeddings,
+            BertIntermediate,
             BertLayer,
             BertModel,
             BertOutput,
@@ -131,6 +134,7 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
+                            "skip_bias_add": self.enable_bias_gelu_fused,
                         },
                     ),
                     SubModuleReplacementDescription(
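
Note: setting "skip_bias_add" from enable_bias_gelu_fused means the replaced linear is expected to return its output and bias separately instead of applying the bias itself, so the bias add can be deferred into the fused GELU kernel sketched above.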
@@ -153,6 +157,14 @@ class BertPolicy(Policy):
                     ),
                 ]
             )
+            if self.enable_bias_gelu_fused:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_jit_fused_bert_intermediate_forward(),
+                    },
+                    policy=policy,
+                    target_key=BertIntermediate,
+                )
 
         if sp_mode == "split_gather":
             self.append_or_create_method_replacement(
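
Putting the two pieces together, the forward installed on BertIntermediate plausibly looks like the sketch below (a guess at the wiring under the skip_bias_add assumption, not ColossalAI's exact helper):

    def get_jit_fused_bert_intermediate_forward():
        # Sketch: assumes self.dense was replaced with skip_bias_add=True
        # and therefore returns (output, bias) with the bias unapplied.
        def forward(self, hidden_states):
            hidden_states, bias = self.dense(hidden_states)
            # bias_gelu is the fused TorchScript kernel sketched earlier
            return bias_gelu(bias, hidden_states)

        return forward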