[hotfix]fix chatglm rmsnorm (#5079)

* fix chatglm

* add error

* fix bugs

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Jianghai 2023-11-22 17:59:22 +08:00 committed by GitHub
parent 67a07e6f64
commit cb450c2861
3 changed files with 48 additions and 9 deletions

View File

@@ -408,12 +408,20 @@ class ChatGLM2InferenceForwards:
        query_layer = query_layer.reshape(
            -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
        )
        key_layer = key_layer.reshape(
            -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
        )
        value_layer = value_layer.reshape(
            -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
        )
        if self.multi_query_attention:
            key_layer = key_layer.reshape(
                -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
            )
            value_layer = value_layer.reshape(
                -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
            )
        else:
            key_layer = key_layer.reshape(
                -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
            )
            value_layer = value_layer.reshape(
                -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
            )
        if infer_state.is_context_stage:
            # first token generation:
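
(Illustrative sketch, not part of the commit: the hunk above branches on multi_query_attention because key/value tensors only carry num_multi_query_groups_per_partition heads when multi-query attention is enabled, but the full num_attention_heads_per_partition otherwise. The head counts below are made-up example values.)

import torch

num_attention_heads = 32      # query heads (example value)
num_multi_query_groups = 2    # shared key/value groups when MQA is on (example value)
head_dim = 128
tokens = 16
multi_query_attention = True

kv_heads = num_multi_query_groups if multi_query_attention else num_attention_heads
key = torch.randn(tokens, kv_heads * head_dim)

if multi_query_attention:
    key = key.reshape(-1, num_multi_query_groups, head_dim)   # -> (16, 2, 128)
else:
    key = key.reshape(-1, num_attention_heads, head_dim)      # -> (16, 32, 128)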

View File

@@ -1,5 +1,7 @@
from functools import partial
from typing import List

import torch
import torch.nn as nn

from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
@@ -7,23 +9,40 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
    ChatGLMModel,
    GLMBlock,
    GLMTransformer,
    RMSNorm,
    SelfAttention,
)

# import colossalai
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
from colossalai.shardformer.policies.chatglm2 import ChatGLMForConditionalGenerationPolicy

from ..modeling._utils import init_to_get_rotary
from ..modeling.chatglm2 import ChatGLM2InferenceForwards

try:
    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
    print("Did not find rms-norm triton kernel")
    print(
        "You can use the following command to install: pip install git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
    )
    HAS_TRITON_RMSNORM = False


class ChatGLM2InferPolicy(ChatGLMModelPolicy):

def get_triton_rmsnorm_forward():
    if HAS_TRITON_RMSNORM:
        def _triton_rmsnorm_forward(self: RMSNorm, hidden_states: torch.Tensor):
            return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.eps)

        return _triton_rmsnorm_forward
    else:
        raise RuntimeError("Did not find rms-norm triton kernel")


class ChatGLM2InferPolicy(ChatGLMForConditionalGenerationPolicy):
    def __init__(self) -> None:
        super().__init__()
@@ -56,6 +75,16 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
        )

        # for rmsnorm and others, we need to check the shape
        infer_forward = None
        if HAS_TRITON_RMSNORM:
            infer_forward = get_triton_rmsnorm_forward()

        if infer_forward is not None:
            method_replacement = {"forward": partial(infer_forward)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=RMSNorm
            )

        self.set_pipeline_forward(
            model_cls=ChatGLMForConditionalGeneration,
            new_forward=ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward,
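
(For reference only, not part of the commit: the Triton kernel swapped in above computes RMSNorm, i.e. it scales the input by the reciprocal root mean square over the hidden dimension and multiplies by the learned weight. A minimal PyTorch sketch of that operation, assuming the standard RMSNorm definition:)

import torch

def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # y = x / sqrt(mean(x^2, dim=-1) + eps) * weight
    variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
    return hidden_states * torch.rsqrt(variance + eps) * weight

x = torch.randn(4, 4096)
w = torch.ones(4096)
print(rmsnorm_reference(x, w, eps=1e-5).shape)  # torch.Size([4, 4096])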

View File

@@ -21,6 +21,7 @@ from ..modeling.llama import LlamaInferenceForwards
try:
    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
@@ -30,6 +31,7 @@ if HAS_TRITON_RMSNORM:
    def get_triton_rmsnorm_forward():
        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
            return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

        return _triton_rmsnorm_forward
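
(Sketch of the method-replacement pattern both the ChatGLM2 and Llama policies rely on, not the shardformer API itself: a module-level function that takes self is installed as the class's forward, so every norm instance dispatches to the faster kernel. ToyRMSNorm and _fast_forward below are hypothetical names.)

import torch
import torch.nn as nn

class ToyRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
        return hidden_states * torch.rsqrt(variance + self.eps) * self.weight

def _fast_forward(self: ToyRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
    # stand-in for lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.eps)
    variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
    return hidden_states * torch.rsqrt(variance + self.eps) * self.weight

# roughly what the policy's {"forward": partial(infer_forward)} replacement amounts to
ToyRMSNorm.forward = _fast_forward
print(ToyRMSNorm(8)(torch.randn(2, 8)).shape)  # torch.Size([2, 8])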