From 5512bdf1fc2047ab47ea6484b876fd9cb3d1a727 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 4 Mar 2024 11:22:34 +0800 Subject: [PATCH] fix: modify model config and add Qwen2RMSNorm --- colossalai/shardformer/layer/normalization.py | 2 +- tests/kit/model_zoo/transformers/qwen2.py | 87 +++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 tests/kit/model_zoo/transformers/qwen2.py diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 43dd153af..bab0e4e1d 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -276,7 +276,7 @@ class FusedRMSNorm(BaseLayerNorm): LazyInitContext.materialize(module) # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm - if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]: + if module.__class__.__name__ in ["LlamaRMSNorm", "Qwen2RMSNorm", "MistralRMSNorm"]: normalized_shape = module.weight.shape[0] eps = module.variance_epsilon elementwise_affine = True diff --git a/tests/kit/model_zoo/transformers/qwen2.py b/tests/kit/model_zoo/transformers/qwen2.py new file mode 100644 index 000000000..3110d8e53 --- /dev/null +++ b/tests/kit/model_zoo/transformers/qwen2.py @@ -0,0 +1,87 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +try: + from transformers import Qwen2Config + + HAS_QWEN2 = True +except ImportError: + HAS_QWEN2 = False + +if HAS_QWEN2: + # =============================== + # Register Qwen2 + # =============================== + + def data_gen(): + # the input ids are corresponding to the sentence + # 'Hello, my dog is cute' + # + # the code is give below: + # ----------------------------------- + # from transformers import Qwen2TokenizerFast + # tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen1.5-7B-Chat") + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + # ----------------------------------- + + input_ids = torch.Tensor([[9707, 11, 847, 5562, 374, 18838], [9707, 11, 847, 5562, 374, 18838]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() + return dict(input_ids=input_ids, attention_mask=attention_mask) + + # label is needed for casual lm + def data_gen_for_casual_lm(): + data = data_gen() + labels = data["input_ids"].clone() + data["labels"] = labels + return data + + # transform the output to a dict + output_transform_fn = lambda x: x + + # function to get the loss + loss_fn = lambda output: output["last_hidden_state"].mean() + loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_seq_classification = lambda output: output["logits"].mean() + + config = Qwen2Config( + hidden_size=128, + intermediate_size=256, + max_window_layers=4, + num_attention_heads=16, + num_hidden_layers=4, + num_key_value_heads=16, + ) + + config.pad_token_id = 0 + + # register the following models + # transformers.Qwen2Model, + # transformers.Qwen2ForCausalLM, + # transformers.Qwen2ForSequenceClassification, + model_zoo.register( + name="transformers_qwen2", + model_fn=lambda: transformers.Qwen2Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_qwen2_for_casual_lm", + model_fn=lambda: transformers.Qwen2ForCausalLM(config), + data_gen_fn=data_gen_for_casual_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_casual_lm, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_qwen2_for_sequence_classification", + model_fn=lambda: transformers.Qwen2ForSequenceClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_seq_classification, + model_attribute=ModelAttribute(has_control_flow=True), + )