[Fix] Fix & Update Inference Tests (compatibility w/ main)

This commit is contained in:
Yuanheng Zhao
2024-05-05 16:28:56 +00:00
parent 56ed09aba5
commit 8754abae24
30 changed files with 32 additions and 30 deletions

View File

@@ -0,0 +1,320 @@
from itertools import product
import numpy as np
import pytest
import torch
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v3,
generate_caches_and_block_tables_vllm,
torch_attn_ref,
)
q_len = 1
def prepare_data(
BATCH_SIZE: int,
HEAD_SIZE: int,
NUM_ATTN_HEADS: int,
NUM_KV_HEADS: int,
MAX_SEQ_LEN: int,
dtype=torch.float16,
device="cuda",
):
# Use the provided maximum sequence length for each sequence when testing with teh same context length,
# otherwise generate random context lengths.
# returns
# q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE]
# k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE]
kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device)
num_tokens = torch.sum(kv_lengths).item()
q_size = (BATCH_SIZE, q_len, NUM_ATTN_HEADS, HEAD_SIZE)
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE)
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2)
return q, k_unpad, v_unpad, kv_lengths
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_flash_decoding_attention(
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
else:
alibi_slopes = None
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
if use_alibi_slopes:
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
torch_padding_mask = torch_padding_mask + alibi_mask
if len(torch_padding_mask.size()) == 4:
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
else:
torch_padding_mask = torch_padding_mask[:, -1:, :]
mid_output = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
)
mid_output_lse = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
)
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
high_precision_q = q.to(torch.float32)
high_precision_k_torch = k_torch.to(torch.float32)
high_precision_v_torch = v_torch.to(torch.float32)
out_ref = torch_attn_ref(
high_precision_q,
high_precision_k_torch,
high_precision_v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
out_ref = torch_attn_ref(
q,
k_torch,
v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
)
inference_ops.flash_decoding_attention(
output,
q.squeeze(2),
k_cache,
v_cache,
kv_seq_lengths,
block_tables,
BLOCK_SIZE,
max_seq_len_across_batch,
mid_output,
mid_output_lse,
alibi_slopes,
sm_scale,
)
# The alibi may introduce relatively large errors
if use_alibi_slopes:
rtol = 1e0
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
try:
from vllm._C import ops as vllm_ops # noqa
HAS_VLLM = True
except ImportError:
HAS_VLLM = False
print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm")
@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_vllm_flash_decoding_attention(
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_vllm(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
kv_scale = 1.0
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
torch_padding_mask = torch_padding_mask + alibi_mask
if len(torch_padding_mask.size()) == 4:
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
else:
torch_padding_mask = torch_padding_mask[:, -1:, :]
else:
alibi_slopes = None
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
high_precision_q = q.to(torch.float32)
high_precision_k_torch = k_torch.to(torch.float32)
high_precision_v_torch = v_torch.to(torch.float32)
out_ref = torch_attn_ref(
high_precision_q,
high_precision_k_torch,
high_precision_v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
out_ref = torch_attn_ref(
q,
k_torch,
v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
)
vllm_ops.paged_attention_v1(
output,
q.squeeze(2),
k_cache,
v_cache,
NUM_KV_HEADS,
sm_scale,
block_tables,
kv_seq_lengths,
BLOCK_SIZE,
max_seq_len_across_batch,
alibi_slopes,
"auto",
kv_scale,
)
# The alibi may introduce relatively large errors
if use_alibi_slopes:
rtol = 1e0
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
if __name__ == "__main__":
BATCH_SIZE = [1, 4, 7, 32]
BLOCK_SIZE = [8, 16, 32]
MAX_NUM_BLOCKS_PER_SEQ = [1, 8, 32]
HEAD_SIZE = [64, 128]
NUM_ATTN_HEADS = [16]
KV_GROUP_NUM = [1, 2, 16]
DTYPE = [torch.float16, torch.float32]
test_combinations = list(
product(BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, DTYPE)
)
for (
batch_size,
block_size,
max_num_blocks_per_seq,
head_size,
num_attn_heads,
kv_group_num,
dtype,
) in test_combinations:
test_flash_decoding_attention(
batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True
)

View File

@@ -0,0 +1,53 @@
import numpy as np
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin
inference_ops = InferenceOpsLoader().load()
def numpy_equal(x, y):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_equal(x_numpy, y_numpy)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
@pytest.mark.parametrize("HEAD_DIM", [64])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_get_cos_and_sin(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN
cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda").to(torch.int32)
max_seq_len_in_batch = lengths.max()
# prefill
cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
cos = torch.zeros_like(cos_ref)
sin = torch.zeros_like(sin_ref)
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, True)
numpy_equal(cos, cos_ref)
numpy_equal(sin, sin_ref)
# decoding
ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
cos = torch.zeros_like(ncos_ref)
sin = torch.zeros_like(nsin_ref)
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, False)
numpy_equal(cos, ncos_ref)
numpy_equal(sin, nsin_ref)
if __name__ == "__main__":
test_get_cos_and_sin(16, 4096, 256, torch.float16)

View File

@@ -0,0 +1,157 @@
import pytest
import torch
import torch.nn.functional as F
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v3,
mock_alloc_single_token,
)
inference_ops = InferenceOpsLoader().load()
HEAD_DIM = 72
def prepare_data(
bsz,
num_kv_heads,
block_size,
max_num_blocks_per_seq,
context_lengths,
device="cuda",
dtype=torch.float16,
):
num_tokens = torch.sum(context_lengths).item()
max_seq_len_in_batch = context_lengths.max()
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3(
key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
k_cache = torch.zeros_like(k_cache_ref)
v_cache = torch.zeros_like(v_cache_ref)
return key, value, k_cache, v_cache, cu_seqlens, block_tables, max_seq_len_in_batch, k_cache_ref, v_cache_ref
def run_decode_copy_kv_to_caches(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
n = 1
max_seq_len = block_size * max_num_blocks_per_seq
dtype = torch.float32
device = get_current_device()
assert max_seq_len > n, "max_seq_len must be greater than n"
past_kv_seq_lengths = (
torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
if same_context_len
else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
)
key, value, k_cache, v_cache, _, block_tables, _, _, _ = prepare_data(
bsz, num_kv_heads, block_size, max_num_blocks_per_seq, past_kv_seq_lengths, device, dtype
)
new_k = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
new_v = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
# mock allocating blocks for the new k/v and update block tables
for _ in range(n):
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
past_kv_seq_lengths += 1
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables)
past_kv_seq_len = past_kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
k_source = new_k.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
k_target = k_target.reshape(v_target.shape)
v_source = new_v.squeeze()
assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
def run_context_copy_kv_to_cache(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
torch.manual_seed(123)
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
max_seq_len = max_num_blocks_per_seq * block_size
dtype = torch.float16
device = get_current_device()
if same_context_len:
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
else:
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
(
key,
value,
k_cache,
v_cache,
cu_seqlens,
block_tables,
max_seq_len_in_batch,
k_cache_ref,
v_cache_ref,
) = prepare_data(bsz, num_kv_heads, block_size, max_num_blocks_per_seq, context_lengths, device, dtype)
inference_ops.context_kv_cache_memcpy(
key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch
)
assert torch.equal(k_cache, k_cache_ref)
assert torch.equal(v_cache, v_cache_ref)
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
def test_kv_cache_memcopy(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
run_context_copy_kv_to_cache(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
run_decode_copy_kv_to_caches(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
if __name__ == "__main__":
test_kv_cache_memcopy(4, 32, 8, 16, True)

View File

@@ -0,0 +1,51 @@
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("M", [2, 4, 8, 16])
@pytest.mark.parametrize("N", [64, 128, 512, 5120])
def test_rms_layernorm(M: int, N: int):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
device = get_current_device()
dtype = torch.float16
eps = 1e-5
x_shape = (M, N)
w_shape = (x_shape[-1],)
weight = torch.ones(w_shape, dtype=dtype, device=device)
residual = torch.rand(x_shape, dtype=dtype, device=device)
residual_copy = residual.clone()
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
x_copy = x.clone()
y_cuda = torch.empty_like(x)
inference_ops.rms_layernorm(y_cuda, x, weight, eps)
y_llama = rms_norm.forward(x).to(dtype)
assert y_cuda.shape == y_llama.shape
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
y_cuda = x
x = x_copy + residual_copy
y_llama = rms_norm.forward(x).to(dtype)
assert y_cuda.shape == y_llama.shape
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
if __name__ == "__main__":
test_rms_layernorm(16, 5120)

View File

@@ -0,0 +1,130 @@
import numpy as np
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from colossalai.kernel.kernel_loader import InferenceOpsLoader
inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_kernels.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
from tests.test_infer.test_kernels.triton.test_rotary_embdding_unpad import torch_rotary_emb
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("SEQ_LEN", [64])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("K_H", [16, 32])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
torch.manual_seed(10)
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
cos_2 = cos[:, : D // 2]
sin_2 = sin[:, : D // 2]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
assert torch.allclose(embd_x0, embd_stimulated_x)
# create data
block_size = 32
max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
q_shape = (TOTAL_TOKENS, H, D)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
k_shape = (TOTAL_TOKENS, K_H, D)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
cos_shape = (TOTAL_TOKENS, D // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
x = 16 // torch.tensor([], dtype=dtype).element_size()
k_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, D // x, block_size, x)
v_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device="cuda")
v = torch.randn_like(k)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v3(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
)
new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda")
new_q = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
new_v = torch.randn_like(new_k)
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")
new_q_copy = new_q.clone()
new_k_copy = new_k.clone()
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
new_q_fp16 = new_q.clone()
new_k_fp16 = new_k.clone()
high_precision_cos = cos[:BATCH_SIZE].to(torch.float32)
high_precision_sin = sin[:BATCH_SIZE].to(torch.float32)
high_precision_q = new_q.to(torch.float32)
high_precision_k = new_k.to(torch.float32)
q_ref = torch_rotary_emb(high_precision_q, high_precision_cos, high_precision_sin).to(torch.float16)
k_ref = torch_rotary_emb(high_precision_k, high_precision_cos, high_precision_sin).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables, True
)
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin, True)
past_kv_seq_len = kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :].squeeze()
k_source = new_k_copy.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
k_target = k_target.reshape(v_target.shape)
v_source = new_v.squeeze()
numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)
numpy_allclose(k_target, k_ref, rtol=rtol, atol=atol)
numpy_allclose(new_q_copy, q_ref, rtol=rtol, atol=atol)
numpy_allclose(new_k_copy, k_ref, rtol=rtol, atol=atol)
assert k_target.shape == k_source.shape
numpy_allclose(k_target, k_source, rtol=rtol, atol=atol)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if dtype == torch.float16:
# After testing cuda fp16 high_precision, it was found to have higher precision than torch fp16. Therefore, the threshold here has been relaxed to pass the test.
rtol = 1e-3
atol = 1e-1
inference_ops.rotary_embedding(new_q_fp16, new_k_fp16, cos, sin, False)
numpy_allclose(new_q_copy, new_q_fp16, rtol=rtol, atol=atol)
numpy_allclose(new_k_copy, new_k_fp16, rtol=rtol, atol=atol)
if __name__ == "__main__":
test_rotary_emb(16, 64, 32, 16, 128, torch.float16)

View File

@@ -0,0 +1,33 @@
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("SHAPE_X", [2])
@pytest.mark.parametrize("SHAPE_Y", [64])
@pytest.mark.parametrize("SHAPE_Z", [11008])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
torch.manual_seed(5)
device = get_current_device()
ref_input = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device)
origin_input = ref_input.clone()
act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
ref_out = act_out * ref_input[1]
origin_out = inference_ops.silu_and_mul(origin_input)
if dtype == torch.float32:
assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)
else:
assert torch.allclose(origin_out, ref_out, atol=1e-3, rtol=1e-3)
if __name__ == "__main__":
test_silu_and_mul(2, 64, 11008, torch.float32)
test_silu_and_mul(2, 64, 11008, torch.float16)