diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index be6d92f01..ac15a4bdb 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -8,6 +8,7 @@ from .gpt import *
 from .gptj import *
 from .llama import *
 from .opt import *
+from .qwen2 import *
 from .sam import *
 from .t5 import *
 from .vit import *
diff --git a/tests/kit/model_zoo/transformers/qwen2.py b/tests/kit/model_zoo/transformers/qwen2.py
index e69de29bb..3110d8e53 100644
--- a/tests/kit/model_zoo/transformers/qwen2.py
+++ b/tests/kit/model_zoo/transformers/qwen2.py
@@ -0,0 +1,87 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+try:
+    from transformers import Qwen2Config
+
+    HAS_QWEN2 = True
+except ImportError:
+    HAS_QWEN2 = False
+
+if HAS_QWEN2:
+    # ===============================
+    # Register Qwen2
+    # ===============================
+
+    def data_gen():
+        # the input ids correspond to the sentence
+        # 'Hello, my dog is cute'
+        #
+        # the code to generate them is given below:
+        # -----------------------------------
+        # from transformers import Qwen2TokenizerFast
+        # tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen1.5-7B-Chat")
+        # input = 'Hello, my dog is cute'
+        # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
+        # -----------------------------------
+
+        input_ids = torch.Tensor([[9707, 11, 847, 5562, 374, 18838], [9707, 11, 847, 5562, 374, 18838]]).long()
+        attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]).long()
+        return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+    # labels are needed for causal LM
+    def data_gen_for_casual_lm():
+        data = data_gen()
+        labels = data["input_ids"].clone()
+        data["labels"] = labels
+        return data
+
+    # transform the output to a dict
+    output_transform_fn = lambda x: x
+
+    # function to get the loss
+    loss_fn = lambda output: output["last_hidden_state"].mean()
+    loss_fn_for_casual_lm = lambda output: output["loss"]
+    loss_fn_for_seq_classification = lambda output: output["logits"].mean()
+
+    config = Qwen2Config(
+        hidden_size=128,
+        intermediate_size=256,
+        max_window_layers=4,
+        num_attention_heads=16,
+        num_hidden_layers=4,
+        num_key_value_heads=16,
+    )
+
+    config.pad_token_id = 0
+
+    # register the following models
+    # transformers.Qwen2Model,
+    # transformers.Qwen2ForCausalLM,
+    # transformers.Qwen2ForSequenceClassification,
+    model_zoo.register(
+        name="transformers_qwen2",
+        model_fn=lambda: transformers.Qwen2Model(config),
+        data_gen_fn=data_gen,
+        output_transform_fn=output_transform_fn,
+        loss_fn=loss_fn,
+        model_attribute=ModelAttribute(has_control_flow=True),
+    )
+    model_zoo.register(
+        name="transformers_qwen2_for_casual_lm",
+        model_fn=lambda: transformers.Qwen2ForCausalLM(config),
+        data_gen_fn=data_gen_for_casual_lm,
+        output_transform_fn=output_transform_fn,
+        loss_fn=loss_fn_for_casual_lm,
+        model_attribute=ModelAttribute(has_control_flow=True),
+    )
+    model_zoo.register(
+        name="transformers_qwen2_for_sequence_classification",
+        model_fn=lambda: transformers.Qwen2ForSequenceClassification(config),
+        data_gen_fn=data_gen,
+        output_transform_fn=output_transform_fn,
+        loss_fn=loss_fn_for_seq_classification,
+        model_attribute=ModelAttribute(has_control_flow=True),
+    )
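
As a quick standalone sanity check of the registrations above (not part of the diff itself), the sketch below rebuilds the same tiny `Qwen2Config` and the same batch that `data_gen` produces, then runs one forward pass through the base model and the causal-LM head, mirroring the registered loss functions. It assumes an installed `transformers` release that ships Qwen2 (4.37.0 or later).

```python
# Standalone sketch, not part of the PR: mirrors data_gen() and the
# registered loss functions to confirm the tiny config runs end to end.
# Assumes transformers >= 4.37 (the first release with Qwen2).
import torch
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Model

config = Qwen2Config(
    hidden_size=128,
    intermediate_size=256,
    max_window_layers=4,
    num_attention_heads=16,
    num_hidden_layers=4,
    num_key_value_heads=16,
    pad_token_id=0,
)

# Same batch as data_gen(): two copies of 'Hello, my dog is cute'.
input_ids = torch.tensor([[9707, 11, 847, 5562, 374, 18838]] * 2)
attention_mask = torch.ones_like(input_ids)

# Base model: the registered loss_fn averages last_hidden_state.
hidden = Qwen2Model(config)(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
assert hidden.shape == (2, 6, config.hidden_size)

# Causal LM: labels are a copy of the input ids, as in data_gen_for_casual_lm().
out = Qwen2ForCausalLM(config)(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
assert out.loss.ndim == 0 and torch.isfinite(out.loss)
```

The randomly initialized tiny config keeps both forward passes cheap enough to run on CPU, which is the point of the model-zoo registrations: shardformer tests exercise the architecture, not pretrained weights.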