[shardformer] support bias_gelu_jit_fused for models (#5647)

* support gelu_bias_fused for gpt2

* fix
Author: flybird11111
Date: 2024-04-29 15:33:51 +08:00 (committed by GitHub)
Parent: 7f8b16635b
Commit: 6af6d6fc9f
8 changed files with 115 additions and 2 deletions


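For context, JitGeLUFunction fuses the bias add and the GeLU activation into a single TorchScript kernel instead of launching them as separate elementwise ops. A minimal sketch of such a fused autograd function, assuming the Megatron-style tanh approximation that colossalai.kernel.jit.bias_gelu is based on (a sketch, not the repo's verbatim code):

import torch

@torch.jit.script
def bias_gelu(bias, y):
    # one fused kernel for (y + bias) followed by tanh-approximated GeLU
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

@torch.jit.script
def bias_gelu_back(g, bias, y):
    # recompute the tanh term and apply the analytic derivative of the approximation
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))
    ff = 0.5 * x * ((1.0 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1.0 + tanh_out)
    return ff * g

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        # the same gradient flows to both the activation input and the bias
        return tmp, tmp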
@@ -1310,3 +1310,18 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         )

     return forward
+
+
+def get_jit_fused_gpt2_mlp_forward():
+    from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states, bias = self.c_fc(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+    return forward
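
Note that self.c_fc(hidden_states) unpacks two values: this replacement forward assumes shardformer has already swapped c_fc for a linear layer configured with skip_bias_add=True, so the layer returns (output, bias) and the bias add is deferred into the fused kernel. A hedged sketch of how the GPT2 policy might wire this forward in (method and attribute names follow shardformer's policy conventions; the exact gating condition in the commit may differ):

# sketch: inside the GPT2 policy's module_policy(), gated on JIT fusion being
# enabled and on the model actually using the plain "gelu" activation
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP

if self.shard_config.enable_jit_fused and self.model.config.activation_function == "gelu":
    self.append_or_create_method_replacement(
        description={
            "forward": get_jit_fused_gpt2_mlp_forward(),
        },
        policy=policy,
        target_key=GPT2MLP,
    )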