[shardformer] support bias_gelu_jit_fused for models (#5647)
* support gelu_bias_fused for gpt2
* support gelu_bias_fused for gpt2 fix fix fix
* fix fix
* fix
@@ -1310,3 +1310,18 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         )
 
     return forward
+
+
+def get_jit_fused_gpt2_mlp_forward():
+    from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states, bias = self.c_fc(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+    return forward
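For readers unfamiliar with the pattern: in this forward, `c_fc` returns its matmul output and bias separately (skip-bias-add style), so the bias addition can be folded into the GeLU kernel instead of running as a separate elementwise pass. Below is a minimal, self-contained sketch of such a fused bias+GeLU autograd function, in the style of Megatron-LM's tanh-approximation kernel. It is illustrative only: `bias_gelu_fused`, `bias_gelu_back`, and `FusedBiasGeLU` are hypothetical names and do not reproduce ColossalAI's exact `GeLUFunction` internals.

```python
import torch


# Tanh-approximation GeLU fused with the preceding bias add.
# torch.jit.script lets TorchScript fuse this elementwise chain into a
# single kernel, avoiding a separate memory round-trip for the bias add.
@torch.jit.script
def bias_gelu_fused(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x = y + bias
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))


@torch.jit.script
def bias_gelu_back(grad_output: torch.Tensor, bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Derivative of the tanh-approximated GeLU, also evaluated as one fused kernel.
    x = y + bias
    tanh_out = torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))
    ff = 0.5 * x * ((1.0 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1.0 + tanh_out)
    return ff * grad_output


class FusedBiasGeLU(torch.autograd.Function):
    """Illustrative autograd wrapper; ColossalAI's JitGeLUFunction plays this role."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input, bias)
        return bias_gelu_fused(bias, input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        input, bias = ctx.saved_tensors
        grad = bias_gelu_back(grad_output, bias, input)
        # Reduce the bias gradient over the broadcast (batch/sequence) dimensions
        # so its shape matches the bias parameter.
        grad_bias = grad.sum(dim=tuple(range(grad.dim() - 1))) if grad.dim() > 1 else grad
        return grad, grad_bias


# Usage: a linear layer that returns its matmul result and bias separately
# feeds the fused op, mirroring `hidden_states, bias = self.c_fc(hidden_states)`.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
b = torch.randn(8, requires_grad=True)
out = FusedBiasGeLU.apply(x @ w, b)
out.sum().backward()
```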