Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] fix chatglm implementation (#5644)
* [shardformer] fix chatglm policy
* [shardformer] fix chatglm flash attn
* [shardformer] update readme
* [shardformer] fix chatglm init
* [shardformer] fix chatglm test
* [pipeline] fix chatglm merge batch
@@ -1,7 +1,6 @@
 import torch
-
-from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+from torch.nn import init
+from transformers import AutoConfig, AutoModelForCausalLM
 
 from ..registry import ModelAttribute, model_zoo
 
@@ -34,19 +33,26 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
 )
 loss_fn = lambda x: x["loss"]
 
-config = ChatGLMConfig(
+config = AutoConfig.from_pretrained(
+    "THUDM/chatglm2-6b",
+    trust_remote_code=True,
     num_layers=2,
     padded_vocab_size=65024,
     hidden_size=64,
     ffn_hidden_size=214,
     num_attention_heads=8,
     kv_channels=16,
     rmsnorm=True,
     original_rope=True,
     use_cache=True,
     multi_query_attention=False,
     torch_dtype=torch.float32,
 )
 
-infer_config = ChatGLMConfig(
+
+infer_config = AutoConfig.from_pretrained(
+    "THUDM/chatglm2-6b",
+    trust_remote_code=True,
     num_layers=2,
     padded_vocab_size=65024,
     hidden_size=128,
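Note on the hunk above: with trust_remote_code=True, AutoConfig.from_pretrained fetches the ChatGLMConfig class published in the THUDM/chatglm2-6b repo itself instead of using the classes vendored under colossalai.shardformer.modeling.chatglm2_6b, and any extra keyword arguments are applied as attribute overrides on the loaded config. A minimal sketch of that override behavior, assuming Hub access (the values mirror the test config above):

    from transformers import AutoConfig

    # Extra kwargs are set on top of the checkpoint's config, which is how
    # the test shrinks ChatGLM2-6B down to a tiny 2-layer model.
    cfg = AutoConfig.from_pretrained(
        "THUDM/chatglm2-6b",
        trust_remote_code=True,  # load the repo's own ChatGLMConfig class
        num_layers=2,
        hidden_size=64,
    )
    assert cfg.num_layers == 2
    assert cfg.hidden_size == 64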
@@ -60,18 +66,18 @@ infer_config = ChatGLMConfig(
     torch_dtype=torch.float32,
 )
 
-model_zoo.register(
-    name="transformers_chatglm",
-    model_fn=lambda: ChatGLMModel(config, empty_init=False),
-    data_gen_fn=data_gen,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn_for_chatglm_model,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
 
+def init_chatglm():
+    model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)
+    for m in model.modules():
+        if m.__class__.__name__ == "RMSNorm":
+            init.ones_(m.weight)
+    return model
+
+
 model_zoo.register(
     name="transformers_chatglm_for_conditional_generation",
-    model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
+    model_fn=init_chatglm,
     data_gen_fn=data_gen_for_conditional_generation,
     output_transform_fn=output_transform_fn,
     loss_fn=loss_fn,
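The replacement model_fn, init_chatglm, builds the model through AutoModelForCausalLM.from_config (again via remote code) and then resets every RMSNorm weight to ones, presumably to give the test a deterministic starting point (the commit message's "fix chatglm init"); matching on __class__.__name__ avoids importing the remote RMSNorm class directly. A hedged sketch of how such a registered entry might be exercised, using only names visible in this file (the real model-zoo test driver may differ):

    # Assumed flow mirroring the registered fields: build the model, generate
    # a batch, transform the output, and compute the loss that the shardformer
    # tests compare across parallel configurations.
    model = init_chatglm()                        # model_fn
    data = data_gen_for_conditional_generation()  # data_gen_fn
    output = output_transform_fn(model(**data))   # output_transform_fn
    loss = loss_fn(output)                        # -> output["loss"]
    loss.backward()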