flybird11111 2025-05-22 14:26:18 +08:00
parent 6a29abdefd
commit bad9c8ab24
2 changed files with 3 additions and 3 deletions

View File

@@ -286,7 +286,6 @@ def new_from_pretrained(
     config.name_or_path = pretrained_model_name_or_path
     # Instantiate model.
-    # init_contexts = [no_init_weights(_enable=_fast_init)]
     init_contexts = [no_init_weights()]
     with ContextManagers(init_contexts):
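
This hunk drops a stale comment that still referenced the old `no_init_weights(_enable=...)` signature; newer transformers releases no longer accept the `_enable` argument, so the context manager is created with no arguments. Below is a minimal sketch (not the repository's code) of how such a skip-init context is typically used; it assumes a transformers version where `no_init_weights` lives in `transformers.modeling_utils` and `ContextManagers` is exported from `transformers.utils`:

```python
# Hedged sketch: skip weight initialization while building a module,
# relying on a later checkpoint load to fill the parameters.
# Assumes no_init_weights() takes no arguments (recent transformers).
import torch.nn as nn
from transformers.modeling_utils import no_init_weights
from transformers.utils import ContextManagers

init_contexts = [no_init_weights()]
with ContextManagers(init_contexts):
    # Inside this context the usual init functions are suppressed, so the
    # parameters keep whatever values allocation gives them; they are
    # expected to be overwritten by the subsequent state_dict load.
    model = nn.Linear(8, 8)
```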

View File

@@ -1,7 +1,7 @@
 import pytest
 import torch
 from packaging import version
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaConfig, apply_rotary_pos_emb

 from colossalai.kernel.triton import decoding_fused_rotary_embedding
 from tests.test_infer.test_kernels.triton.kernel_utils import (
@@ -45,7 +45,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
     # our crafted op equals to Transformers
     x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
     x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
-    emb = LlamaRotaryEmbedding(D)
+    config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
+    emb = LlamaRotaryEmbedding(config)
     position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
     cos, sin = emb(x0, position_ids)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
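
For context on this test change: in recent transformers releases, `LlamaRotaryEmbedding` is constructed from a `LlamaConfig` (the head dimension is derived as `hidden_size // num_attention_heads`) rather than from a bare `dim` argument, which is what the new `config = LlamaConfig(...)` line accommodates. A minimal sketch of the updated usage follows; the concrete sizes are illustrative, not taken from the test:

```python
# Hedged sketch of the config-based rotary-embedding API assumed by the test.
import torch
from transformers.models.llama.modeling_llama import (
    LlamaConfig,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)

B, H, S, D = 2, 4, 16, 32  # batch, heads, sequence length, head dim (illustrative)
config = LlamaConfig(max_position_embeddings=S, num_attention_heads=H, hidden_size=H * D)
emb = LlamaRotaryEmbedding(config)

q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
position_ids = torch.arange(S).unsqueeze(0).repeat(B, 1)  # (B, S) token positions
cos, sin = emb(q, position_ids)                            # per-position cos/sin tables
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)        # rotated query and key tensors
```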