Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-27 04:33:04 +00:00
[shardformer] support bias_gelu_jit_fused for models (#5647)
* support gelu_bias_fused for gpt2
* follow-up fixes
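For context: bias_gelu_jit_fused fuses the intermediate linear layer's bias addition with the GELU activation into one TorchScript kernel, so the MLP does a single elementwise pass instead of two. A minimal sketch of the kernel idea, assuming the tanh GELU approximation; the name bias_gelu_fused below is illustrative, not ColossalAI's actual API:

import torch

@torch.jit.script
def bias_gelu_fused(bias: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # One fused elementwise kernel: add the deferred bias, then apply
    # the tanh-approximated GELU (0.79788456 ~= sqrt(2/pi)).
    y = x + bias
    return y * 0.5 * (1.0 + torch.tanh(0.79788456 * y * (1.0 + 0.044715 * y * y)))

The diff below wires this idea into the shardformer BERT policy: import the fused forward, decide in preprocess whether fusion applies, tell the sharded dense layer to skip its bias add, and register the fused forward for BertIntermediate.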
@@ -12,6 +12,7 @@ from ..modeling.bert import (
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
     get_bert_flash_attention_forward,
+    get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
 )
@@ -38,11 +39,13 @@ class BertPolicy(Policy):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
         return self.model
 
     def module_policy(self):
         from transformers.models.bert.modeling_bert import (
             BertEmbeddings,
+            BertIntermediate,
             BertLayer,
             BertModel,
             BertOutput,
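Note the guard added in preprocess: fusion is enabled only when the shard config requests JIT fusion and the model's activation is exactly "gelu"; variants such as "gelu_new" must keep the unfused path. A toy check of that gating, with hypothetical stand-in values:

# Stand-ins for shard_config.enable_jit_fused and model.config.hidden_act:
enable_jit_fused = True
hidden_act = "gelu_new"
enable_bias_gelu_fused = enable_jit_fused and hidden_act == "gelu"
assert not enable_bias_gelu_fused  # only the exact "gelu" takes the fused kernel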
@@ -131,6 +134,7 @@ class BertPolicy(Policy):
                     kwargs={
                         "seq_parallel_mode": sp_mode,
                         "overlap": overlap,
+                        "skip_bias_add": self.enable_bias_gelu_fused,
                     },
                 ),
                 SubModuleReplacementDescription(
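The new skip_bias_add kwarg asks the sharded intermediate dense layer not to add its bias during the matmul and to hand the bias back separately, so it can be folded into the fused GELU. A minimal sketch of that contract, assuming a plain (unsharded) linear; linear_skip_bias_add is a hypothetical stand-in for the column-parallel layer's behavior:

import torch
import torch.nn.functional as F

def linear_skip_bias_add(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
    # Return the un-biased matmul output together with the bias so a
    # downstream fused kernel can apply bias + GELU in one pass.
    return F.linear(x, weight), bias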
@@ -153,6 +157,14 @@ class BertPolicy(Policy):
                 ),
             ]
         )
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_bert_intermediate_forward(),
+                },
+                policy=policy,
+                target_key=BertIntermediate,
+            )
 
         if sp_mode == "split_gather":
             self.append_or_create_method_replacement(
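Putting the pieces together, the method replacement registered above swaps BertIntermediate.forward for a bias+GELU-fused version. The actual get_jit_fused_bert_intermediate_forward lives in ColossalAI's shardformer modeling code; below is a hedged approximation only, reusing the bias_gelu_fused sketch from earlier:

import torch

def get_jit_fused_bert_intermediate_forward_sketch():
    # Approximation: with skip_bias_add enabled, self.dense returns
    # (output, bias) instead of output + bias, and the fused kernel
    # applies bias + GELU together.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, bias = self.dense(hidden_states)
        return bias_gelu_fused(bias, hidden_states)

    return forward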