[upgrade] Upgrade vit (#6308)

* fix

* fix

* fix rotary embedding test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111 authored 2025-05-21 16:14:20 +08:00, committed by GitHub
parent d0e13b85fd
commit 04516bb756
4 changed files with 6 additions and 7 deletions


@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 import torch
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
@@ -33,7 +33,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
-    emb = LlamaRotaryEmbedding(D)
+    config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
+    emb = LlamaRotaryEmbedding(config)
     cos, sin = emb(x0, position_ids)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
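
For reference, a minimal standalone sketch of the updated construction pattern the diff moves to, where LlamaRotaryEmbedding is built from a LlamaConfig instead of a raw head dimension. The sizes and tensor names below are illustrative, not taken from the test:

import torch
from transformers.models.llama.modeling_llama import (
    LlamaConfig,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)

# Illustrative sizes (assumed for this sketch, not from the test file)
BATCH_SIZE, SEQ_LEN, H, D = 2, 16, 4, 32

# The rotary dimension is derived inside the module as
# hidden_size // num_attention_heads == D
config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
emb = LlamaRotaryEmbedding(config)

q = torch.randn(BATCH_SIZE, H, SEQ_LEN, D)
k = torch.randn(BATCH_SIZE, H, SEQ_LEN, D)
position_ids = torch.arange(SEQ_LEN).unsqueeze(0).expand(BATCH_SIZE, -1)

cos, sin = emb(q, position_ids)  # each of shape (BATCH_SIZE, SEQ_LEN, D)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

The first forward argument is only used for its device/dtype; the positional information comes entirely from position_ids, which is why the test can keep passing x0 unchanged after swapping the constructor.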