Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-23 10:30:03 +00:00
[upgrade]Upgrade vit (#6308)
* fix
* fix
* fix rotate embedding test
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 import torch
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 
@@ -33,7 +33,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
 
     position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
 
-    emb = LlamaRotaryEmbedding(D)
+    config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
+    emb = LlamaRotaryEmbedding(config)
 
     cos, sin = emb(x0, position_ids)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
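The change tracks the newer transformers API, in which LlamaRotaryEmbedding is constructed from a LlamaConfig rather than a bare head dimension. Below is a minimal sketch of the updated call path, assuming a recent transformers release that accepts a config as the first constructor argument; the tensor names q/k and the sizes used are illustrative placeholders, not part of the test above.

    import torch
    from transformers.models.llama.modeling_llama import (
        LlamaConfig,
        LlamaRotaryEmbedding,
        apply_rotary_pos_emb,
    )

    BATCH_SIZE, SEQ_LEN, H, D = 2, 16, 4, 64  # illustrative sizes (assumed, not from the test)

    # hidden_size // num_attention_heads gives the per-head dim D used for RoPE.
    config = LlamaConfig(max_position_embeddings=SEQ_LEN, num_attention_heads=H, hidden_size=H * D)
    emb = LlamaRotaryEmbedding(config)

    q = torch.randn(BATCH_SIZE, H, SEQ_LEN, D)  # (batch, heads, seq, head_dim)
    k = torch.randn(BATCH_SIZE, H, SEQ_LEN, D)
    position_ids = torch.arange(SEQ_LEN).repeat(BATCH_SIZE, 1)  # (batch, seq)

    # forward(x, position_ids) returns the cos/sin caches; x is only used for device/dtype.
    cos, sin = emb(q, position_ids)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    print(q_rot.shape, k_rot.shape)  # torch.Size([2, 4, 16, 64]) for both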