From bad9c8ab2493bc418c8998ca07c47699cb1225c0 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 22 May 2025 14:26:18 +0800 Subject: [PATCH] fix --- colossalai/lazy/pretrained.py | 1 - .../test_kernels/triton/test_rotary_embdding_unpad.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py index 66f4cf3bb..684be9223 100644 --- a/colossalai/lazy/pretrained.py +++ b/colossalai/lazy/pretrained.py @@ -286,7 +286,6 @@ def new_from_pretrained( config.name_or_path = pretrained_model_name_or_path # Instantiate model. - # init_contexts = [no_init_weights(_enable=_fast_init)] init_contexts = [no_init_weights()] with ContextManagers(init_contexts): diff --git a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py index 78b7ba81c..f48bab377 100644 --- a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py @@ -1,7 +1,7 @@ import pytest import torch from packaging import version -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaConfig, apply_rotary_pos_emb from colossalai.kernel.triton import decoding_fused_rotary_embedding from tests.test_infer.test_kernels.triton.kernel_utils import ( @@ -45,7 +45,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout): # our crafted op equals to Transformers x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype) x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype) - emb = LlamaRotaryEmbedding(D) + config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D) + emb = LlamaRotaryEmbedding(config) position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN)) cos, sin = emb(x0, position_ids) embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)