Mirror of https://github.com/hpcaitech/ColossalAI.git
fix

Remove the stale commented-out no_init_weights(_enable=_fast_init) call and update the rotary-embedding test to build LlamaRotaryEmbedding from a LlamaConfig, matching the newer transformers API.
commit bad9c8ab24
parent 6a29abdefd
@@ -286,7 +286,6 @@ def new_from_pretrained(
     config.name_or_path = pretrained_model_name_or_path

     # Instantiate model.
-    # init_contexts = [no_init_weights(_enable=_fast_init)]
     init_contexts = [no_init_weights()]

     with ContextManagers(init_contexts):
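This hunk drops a stale commented-out call: recent transformers releases removed the _enable argument (and the associated _fast_init flag) from no_init_weights, so only the argument-free form is kept. As a rough illustration of the surrounding pattern, here is a minimal sketch; the helper name build_without_init is hypothetical and not ColossalAI's API, it just shows instantiating a model under no_init_weights while tolerating both signatures.

# Minimal sketch, not ColossalAI's actual helper: instantiate a model with
# weight initialization skipped, handling both the old and the new
# signature of transformers' no_init_weights.
import inspect

import torch.nn as nn
from transformers.modeling_utils import no_init_weights
from transformers.utils import ContextManagers


def build_without_init(model_cls, *args, **kwargs):
    if "_enable" in inspect.signature(no_init_weights).parameters:
        init_contexts = [no_init_weights(_enable=True)]  # older transformers
    else:
        init_contexts = [no_init_weights()]  # newer transformers, as in the hunk above
    with ContextManagers(init_contexts):
        return model_cls(*args, **kwargs)


# Toy usage: skipping init only affects Hugging Face PreTrainedModel subclasses;
# a plain nn.Linear is used here just to show the call shape.
layer = build_without_init(nn.Linear, 8, 8)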
@@ -1,7 +1,7 @@
 import pytest
 import torch
 from packaging import version
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaConfig, apply_rotary_pos_emb

 from colossalai.kernel.triton import decoding_fused_rotary_embedding
 from tests.test_infer.test_kernels.triton.kernel_utils import (
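A side note on the new import: LlamaConfig lives in transformers.models.llama.configuration_llama and is re-exported by modeling_llama, so the one-line import in the hunk works. An equivalent, slightly more explicit form (sketch only; the test keeps the single line) would be:

# Equivalent imports, splitting the config out from the modeling module.
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb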
@@ -45,7 +45,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
     # our crafted op equals to Transformers
     x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
     x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
-    emb = LlamaRotaryEmbedding(D)
+    config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
+    emb = LlamaRotaryEmbedding(config)
     position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
     cos, sin = emb(x0, position_ids)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
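The test change mirrors the updated Hugging Face API: recent transformers versions construct LlamaRotaryEmbedding from a LlamaConfig, deriving the head dimension as hidden_size // num_attention_heads, instead of taking the head dimension directly. Below is a standalone sketch of that usage, assuming such a recent transformers release; the sizes are illustrative and not the test's parametrization.

# Standalone sketch of the config-based rotary embedding API (illustrative sizes).
import torch
from transformers.models.llama.modeling_llama import (
    LlamaConfig,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)

BATCH_SIZE, SEQ_LEN, H, D = 2, 16, 4, 32
dtype = torch.float16

# head_dim is derived from the config: hidden_size // num_attention_heads == D
config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
emb = LlamaRotaryEmbedding(config)

q = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
k = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
position_ids = torch.arange(BATCH_SIZE * SEQ_LEN).reshape((BATCH_SIZE, SEQ_LEN))

cos, sin = emb(q, position_ids)  # one (BATCH_SIZE, SEQ_LEN, D) tensor each
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)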