Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-04 01:29:55 +00:00)
[hotfix] fix chatglm rmsnorm (#5079)

* fix chatglm
* add error
* fix bugs

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

This commit is contained in:
parent 67a07e6f64
commit cb450c2861
```diff
@@ -408,12 +408,20 @@ class ChatGLM2InferenceForwards:
         query_layer = query_layer.reshape(
             -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
         )
-        key_layer = key_layer.reshape(
-            -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
-        )
-        value_layer = value_layer.reshape(
-            -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
-        )
+        if self.multi_query_attention:
+            key_layer = key_layer.reshape(
+                -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+            )
+            value_layer = value_layer.reshape(
+                -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+            )
+        else:
+            key_layer = key_layer.reshape(
+                -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+            )
+            value_layer = value_layer.reshape(
+                -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+            )
 
         if infer_state.is_context_stage:
             # first token generation:
```
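Context for the fix: when `multi_query_attention` is disabled, ChatGLM2's key/value projections carry the full `num_attention_heads_per_partition` heads, so the old unconditional reshape with `num_multi_query_groups_per_partition` produced wrong shapes. A minimal sketch of the shape arithmetic (the sizes below are illustrative assumptions, not the model's real config values):

```python
import torch

# Illustrative sizes (assumptions, not ChatGLM2's actual config).
num_attention_heads = 32       # query heads per partition
num_multi_query_groups = 2     # KV groups when multi-query attention is on
head_dim = 128
tokens = 4                     # flattened batch * sequence length

multi_query_attention = True
kv_heads = num_multi_query_groups if multi_query_attention else num_attention_heads

query = torch.randn(tokens, num_attention_heads * head_dim)
key = torch.randn(tokens, kv_heads * head_dim)

# Each tensor must be reshaped with the head count its projection was sized
# for; using the MQA group count on a full-width K/V (the pre-fix behavior
# when multi_query_attention is off) fails with a shape mismatch.
query = query.reshape(-1, num_attention_heads, head_dim)
key = key.reshape(-1, kv_heads, head_dim)
print(query.shape, key.shape)  # torch.Size([4, 32, 128]) torch.Size([4, 2, 128])
```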
```diff
@@ -1,5 +1,7 @@
+from functools import partial
 from typing import List
 
+import torch
 import torch.nn as nn
 
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
@@ -7,23 +9,40 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     ChatGLMModel,
     GLMBlock,
     GLMTransformer,
+    RMSNorm,
     SelfAttention,
 )
 
 # import colossalai
-from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
+from colossalai.shardformer.policies.chatglm2 import ChatGLMForConditionalGenerationPolicy
 
 from ..modeling._utils import init_to_get_rotary
 from ..modeling.chatglm2 import ChatGLM2InferenceForwards
 
+try:
+    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
+
+    HAS_TRITON_RMSNORM = True
+except:
+    print("you should install triton from https://github.com/openai/triton")
+    print("Did not find rms-norm triton kernel")
+    print(
+        "You can use the following command to install: pip install git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
+    )
+    HAS_TRITON_RMSNORM = False
+
+
-class ChatGLM2InferPolicy(ChatGLMModelPolicy):
+def get_triton_rmsnorm_forward():
+    if HAS_TRITON_RMSNORM:
+
+        def _triton_rmsnorm_forward(self: RMSNorm, hidden_states: torch.Tensor):
+            return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.eps)
+
+        return _triton_rmsnorm_forward
+    else:
+        raise RuntimeError("Did not find rms-norm triton kernel")
+
+
+class ChatGLM2InferPolicy(ChatGLMForConditionalGenerationPolicy):
     def __init__(self) -> None:
         super().__init__()
```
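For reference, the lightllm Triton kernel that `_triton_rmsnorm_forward` delegates to implements standard RMSNorm. Below is a plain-PyTorch sketch of the same math, based on the usual RMSNorm formula rather than on code from this commit; it can serve as a fallback or to check the Triton path numerically:

```python
import torch

def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # RMSNorm over the last dimension: x * rsqrt(mean(x^2) + eps) * weight.
    # Compute in float32 for numerical stability, then cast back.
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)
```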
```diff
@@ -56,6 +75,16 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
         )
         # for rmsnorm and others, we need to check the shape
 
+        infer_forward = None
+
+        if HAS_TRITON_RMSNORM:
+            infer_forward = get_triton_rmsnorm_forward()
+        if infer_forward is not None:
+            method_replacement = {"forward": partial(infer_forward)}
+            self.append_or_create_method_replacement(
+                description=method_replacement, policy=policy, target_key=RMSNorm
+            )
+
         self.set_pipeline_forward(
             model_cls=ChatGLMForConditionalGeneration,
             new_forward=ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward,
```
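What the replacement above amounts to: the policy rebinds `forward` on `RMSNorm` modules so calls dispatch to the Triton-backed function. A simplified, hypothetical sketch of such a rebinding (shardformer's `append_or_create_method_replacement` does the real bookkeeping):

```python
import types

class RMSNormDemo:
    """Stand-in for the real RMSNorm module (hypothetical, for illustration)."""
    def __init__(self, eps: float = 1e-5):
        self.eps = eps

    def forward(self, hidden_states):
        return "eager python rmsnorm"

def triton_like_forward(self, hidden_states):
    # Stands in for _triton_rmsnorm_forward; receives the module as `self`,
    # keeping access to attributes like self.eps.
    return "triton rmsnorm kernel"

norm = RMSNormDemo()
norm.forward = types.MethodType(triton_like_forward, norm)
print(norm.forward(None))  # -> "triton rmsnorm kernel"
```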
```diff
@@ -21,6 +21,7 @@ from ..modeling.llama import LlamaInferenceForwards
 
 try:
     from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
+
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
@@ -30,6 +31,7 @@ if HAS_TRITON_RMSNORM:
 
     def get_triton_rmsnorm_forward():
         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
+
         return _triton_rmsnorm_forward
```
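Note the attribute difference between the two wrappers: ChatGLM's `RMSNorm` exposes its epsilon as `self.eps`, while Hugging Face's `LlamaRMSNorm` exposes it as `self.variance_epsilon`, which is why each policy passes a different attribute to the same kernel. A quick sanity check of the underlying formula against the eager Llama module (assumes `transformers` is installed; the sizes and tolerance are illustrative):

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

# The eager LlamaRMSNorm should agree with the plain RMSNorm formula the
# Triton kernel implements: x * rsqrt(mean(x^2) + eps) * weight.
norm = LlamaRMSNorm(64, eps=1e-6)
x = torch.randn(2, 64)
variance = x.pow(2).mean(-1, keepdim=True)
ref = norm.weight * (x * torch.rsqrt(variance + norm.variance_epsilon))
assert torch.allclose(norm(x), ref, atol=1e-5)
```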