[Fix] Fix & Update Inference Tests (compatibility w/ main)

This commit is contained in:
Yuanheng Zhao
2024-05-05 16:28:56 +00:00
parent 56ed09aba5
commit 8754abae24
30 changed files with 32 additions and 30 deletions

View File

@@ -0,0 +1,320 @@
from itertools import product
import numpy as np
import pytest
import torch
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v3,
generate_caches_and_block_tables_vllm,
torch_attn_ref,
)
q_len = 1
def prepare_data(
BATCH_SIZE: int,
HEAD_SIZE: int,
NUM_ATTN_HEADS: int,
NUM_KV_HEADS: int,
MAX_SEQ_LEN: int,
dtype=torch.float16,
device="cuda",
):
# Use the provided maximum sequence length for each sequence when testing with teh same context length,
# otherwise generate random context lengths.
# returns
# q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE]
# k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE]
kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device)
num_tokens = torch.sum(kv_lengths).item()
q_size = (BATCH_SIZE, q_len, NUM_ATTN_HEADS, HEAD_SIZE)
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE)
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2)
return q, k_unpad, v_unpad, kv_lengths
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_flash_decoding_attention(
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
else:
alibi_slopes = None
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
if use_alibi_slopes:
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
torch_padding_mask = torch_padding_mask + alibi_mask
if len(torch_padding_mask.size()) == 4:
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
else:
torch_padding_mask = torch_padding_mask[:, -1:, :]
mid_output = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
)
mid_output_lse = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
)
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
high_precision_q = q.to(torch.float32)
high_precision_k_torch = k_torch.to(torch.float32)
high_precision_v_torch = v_torch.to(torch.float32)
out_ref = torch_attn_ref(
high_precision_q,
high_precision_k_torch,
high_precision_v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
out_ref = torch_attn_ref(
q,
k_torch,
v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
)
inference_ops.flash_decoding_attention(
output,
q.squeeze(2),
k_cache,
v_cache,
kv_seq_lengths,
block_tables,
BLOCK_SIZE,
max_seq_len_across_batch,
mid_output,
mid_output_lse,
alibi_slopes,
sm_scale,
)
# The alibi may introduce relatively large errors
if use_alibi_slopes:
rtol = 1e0
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
try:
from vllm._C import ops as vllm_ops # noqa
HAS_VLLM = True
except ImportError:
HAS_VLLM = False
print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm")
@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm")
@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_vllm_flash_decoding_attention(
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_vllm(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
kv_scale = 1.0
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
torch_padding_mask = torch_padding_mask + alibi_mask
if len(torch_padding_mask.size()) == 4:
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
else:
torch_padding_mask = torch_padding_mask[:, -1:, :]
else:
alibi_slopes = None
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
high_precision_q = q.to(torch.float32)
high_precision_k_torch = k_torch.to(torch.float32)
high_precision_v_torch = v_torch.to(torch.float32)
out_ref = torch_attn_ref(
high_precision_q,
high_precision_k_torch,
high_precision_v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
out_ref = torch_attn_ref(
q,
k_torch,
v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
)
vllm_ops.paged_attention_v1(
output,
q.squeeze(2),
k_cache,
v_cache,
NUM_KV_HEADS,
sm_scale,
block_tables,
kv_seq_lengths,
BLOCK_SIZE,
max_seq_len_across_batch,
alibi_slopes,
"auto",
kv_scale,
)
# The alibi may introduce relatively large errors
if use_alibi_slopes:
rtol = 1e0
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
if __name__ == "__main__":
BATCH_SIZE = [1, 4, 7, 32]
BLOCK_SIZE = [8, 16, 32]
MAX_NUM_BLOCKS_PER_SEQ = [1, 8, 32]
HEAD_SIZE = [64, 128]
NUM_ATTN_HEADS = [16]
KV_GROUP_NUM = [1, 2, 16]
DTYPE = [torch.float16, torch.float32]
test_combinations = list(
product(BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, DTYPE)
)
for (
batch_size,
block_size,
max_num_blocks_per_seq,
head_size,
num_attn_heads,
kv_group_num,
dtype,
) in test_combinations:
test_flash_decoding_attention(
batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True
)

View File

@@ -0,0 +1,53 @@
import numpy as np
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin
inference_ops = InferenceOpsLoader().load()
def numpy_equal(x, y):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_equal(x_numpy, y_numpy)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
@pytest.mark.parametrize("HEAD_DIM", [64])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_get_cos_and_sin(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN
cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda").to(torch.int32)
max_seq_len_in_batch = lengths.max()
# prefill
cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
cos = torch.zeros_like(cos_ref)
sin = torch.zeros_like(sin_ref)
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, True)
numpy_equal(cos, cos_ref)
numpy_equal(sin, sin_ref)
# decoding
ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
cos = torch.zeros_like(ncos_ref)
sin = torch.zeros_like(nsin_ref)
inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, False)
numpy_equal(cos, ncos_ref)
numpy_equal(sin, nsin_ref)
if __name__ == "__main__":
test_get_cos_and_sin(16, 4096, 256, torch.float16)

View File

@@ -0,0 +1,157 @@
import pytest
import torch
import torch.nn.functional as F
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v3,
mock_alloc_single_token,
)
inference_ops = InferenceOpsLoader().load()
HEAD_DIM = 72
def prepare_data(
bsz,
num_kv_heads,
block_size,
max_num_blocks_per_seq,
context_lengths,
device="cuda",
dtype=torch.float16,
):
num_tokens = torch.sum(context_lengths).item()
max_seq_len_in_batch = context_lengths.max()
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3(
key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
k_cache = torch.zeros_like(k_cache_ref)
v_cache = torch.zeros_like(v_cache_ref)
return key, value, k_cache, v_cache, cu_seqlens, block_tables, max_seq_len_in_batch, k_cache_ref, v_cache_ref
def run_decode_copy_kv_to_caches(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
n = 1
max_seq_len = block_size * max_num_blocks_per_seq
dtype = torch.float32
device = get_current_device()
assert max_seq_len > n, "max_seq_len must be greater than n"
past_kv_seq_lengths = (
torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
if same_context_len
else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
)
key, value, k_cache, v_cache, _, block_tables, _, _, _ = prepare_data(
bsz, num_kv_heads, block_size, max_num_blocks_per_seq, past_kv_seq_lengths, device, dtype
)
new_k = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
new_v = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
# mock allocating blocks for the new k/v and update block tables
for _ in range(n):
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
past_kv_seq_lengths += 1
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables)
past_kv_seq_len = past_kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
k_source = new_k.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
k_target = k_target.reshape(v_target.shape)
v_source = new_v.squeeze()
assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
def run_context_copy_kv_to_cache(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
torch.manual_seed(123)
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
max_seq_len = max_num_blocks_per_seq * block_size
dtype = torch.float16
device = get_current_device()
if same_context_len:
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
else:
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
(
key,
value,
k_cache,
v_cache,
cu_seqlens,
block_tables,
max_seq_len_in_batch,
k_cache_ref,
v_cache_ref,
) = prepare_data(bsz, num_kv_heads, block_size, max_num_blocks_per_seq, context_lengths, device, dtype)
inference_ops.context_kv_cache_memcpy(
key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch
)
assert torch.equal(k_cache, k_cache_ref)
assert torch.equal(v_cache, v_cache_ref)
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
def test_kv_cache_memcopy(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
run_context_copy_kv_to_cache(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
run_decode_copy_kv_to_caches(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
if __name__ == "__main__":
test_kv_cache_memcopy(4, 32, 8, 16, True)

View File

@@ -0,0 +1,51 @@
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("M", [2, 4, 8, 16])
@pytest.mark.parametrize("N", [64, 128, 512, 5120])
def test_rms_layernorm(M: int, N: int):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
device = get_current_device()
dtype = torch.float16
eps = 1e-5
x_shape = (M, N)
w_shape = (x_shape[-1],)
weight = torch.ones(w_shape, dtype=dtype, device=device)
residual = torch.rand(x_shape, dtype=dtype, device=device)
residual_copy = residual.clone()
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
x_copy = x.clone()
y_cuda = torch.empty_like(x)
inference_ops.rms_layernorm(y_cuda, x, weight, eps)
y_llama = rms_norm.forward(x).to(dtype)
assert y_cuda.shape == y_llama.shape
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
y_cuda = x
x = x_copy + residual_copy
y_llama = rms_norm.forward(x).to(dtype)
assert y_cuda.shape == y_llama.shape
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
if __name__ == "__main__":
test_rms_layernorm(16, 5120)

View File

@@ -0,0 +1,130 @@
import numpy as np
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from colossalai.kernel.kernel_loader import InferenceOpsLoader
inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_kernels.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
from tests.test_infer.test_kernels.triton.test_rotary_embdding_unpad import torch_rotary_emb
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("SEQ_LEN", [64])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("K_H", [16, 32])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
torch.manual_seed(10)
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
cos_2 = cos[:, : D // 2]
sin_2 = sin[:, : D // 2]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
assert torch.allclose(embd_x0, embd_stimulated_x)
# create data
block_size = 32
max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
q_shape = (TOTAL_TOKENS, H, D)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
k_shape = (TOTAL_TOKENS, K_H, D)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
cos_shape = (TOTAL_TOKENS, D // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
x = 16 // torch.tensor([], dtype=dtype).element_size()
k_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, D // x, block_size, x)
v_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device="cuda")
v = torch.randn_like(k)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v3(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
)
new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda")
new_q = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
new_v = torch.randn_like(new_k)
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")
new_q_copy = new_q.clone()
new_k_copy = new_k.clone()
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
new_q_fp16 = new_q.clone()
new_k_fp16 = new_k.clone()
high_precision_cos = cos[:BATCH_SIZE].to(torch.float32)
high_precision_sin = sin[:BATCH_SIZE].to(torch.float32)
high_precision_q = new_q.to(torch.float32)
high_precision_k = new_k.to(torch.float32)
q_ref = torch_rotary_emb(high_precision_q, high_precision_cos, high_precision_sin).to(torch.float16)
k_ref = torch_rotary_emb(high_precision_k, high_precision_cos, high_precision_sin).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables, True
)
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin, True)
past_kv_seq_len = kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :].squeeze()
k_source = new_k_copy.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
k_target = k_target.reshape(v_target.shape)
v_source = new_v.squeeze()
numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)
numpy_allclose(k_target, k_ref, rtol=rtol, atol=atol)
numpy_allclose(new_q_copy, q_ref, rtol=rtol, atol=atol)
numpy_allclose(new_k_copy, k_ref, rtol=rtol, atol=atol)
assert k_target.shape == k_source.shape
numpy_allclose(k_target, k_source, rtol=rtol, atol=atol)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if dtype == torch.float16:
# After testing cuda fp16 high_precision, it was found to have higher precision than torch fp16. Therefore, the threshold here has been relaxed to pass the test.
rtol = 1e-3
atol = 1e-1
inference_ops.rotary_embedding(new_q_fp16, new_k_fp16, cos, sin, False)
numpy_allclose(new_q_copy, new_q_fp16, rtol=rtol, atol=atol)
numpy_allclose(new_k_copy, new_k_fp16, rtol=rtol, atol=atol)
if __name__ == "__main__":
test_rotary_emb(16, 64, 32, 16, 128, torch.float16)

View File

@@ -0,0 +1,33 @@
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("SHAPE_X", [2])
@pytest.mark.parametrize("SHAPE_Y", [64])
@pytest.mark.parametrize("SHAPE_Z", [11008])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
torch.manual_seed(5)
device = get_current_device()
ref_input = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device)
origin_input = ref_input.clone()
act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
ref_out = act_out * ref_input[1]
origin_out = inference_ops.silu_and_mul(origin_input)
if dtype == torch.float32:
assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)
else:
assert torch.allclose(origin_out, ref_out, atol=1e-3, rtol=1e-3)
if __name__ == "__main__":
test_silu_and_mul(2, 64, 11008, torch.float32)
test_silu_and_mul(2, 64, 11008, torch.float16)

View File

@@ -0,0 +1,348 @@
from typing import Tuple
import torch
from torch.nn import functional as F
# This function is adapted from src/transformers/models/llama/modeling_llama.py
# in huggingface transformers repository
# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (bsz, num_key_value_heads, seq_len, head_dim) to (bsz, num_attention_heads, seq_len, head_dim)
"""
if n_rep == 1:
return hidden_states
bsz, num_key_value_heads, seq_len, head_dim = hidden_states.shape
hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim)
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
assert q_len <= kv_len
causal_mask = torch.full((q_len, q_len), fill_value=float("-inf"), device=device).triu(diagonal=1)
padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
for i in range(bsz):
cur_seq_len = kv_lengths[i].item()
assert cur_seq_len <= kv_len
padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf")
padding_mask[:, :, -q_len:, -q_len:] += causal_mask
return padding_mask
# Attention calculation adapted from HuggingFace transformers repository
# src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
def torch_attn_ref(
q: torch.Tensor, # [bsz, num_heads, q_len, head_dim]
k: torch.Tensor, # [bsz, num_heads, kv_len, head_dim]
v: torch.Tensor, # [bsz, num_heads, kv_len, head_dim]
attention_mask: torch.Tensor, # [bsz, 1, q_len, kv_len]
bsz: int,
q_len: int,
kv_len: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
) -> torch.Tensor:
assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim
# repeat kv for GQA and MQA
# k/v won't change if kv_group_num is 1
assert num_heads % num_kv_heads == 0, "Number of heads is not multiple of kv heads"
kv_group_num = num_heads // num_kv_heads
k = repeat_kv(k, kv_group_num)
v = repeat_kv(v, kv_group_num)
qk = torch.matmul(q, k.transpose(2, 3))
attn_scores = qk / (head_dim**0.5)
assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
if attention_mask is not None:
attn_scores = attn_scores + attention_mask
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
out = torch.matmul(attn_weights, v)
if out.size() != (bsz, num_heads, q_len, head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" f" {out.size()}"
)
out = out.transpose(1, 2).contiguous()
out = out.view(-1, out.size(-2), out.size(-1))
# out [bsz * q_len, num_heads, head_dim]
return out
def mock_alloc_block_table_and_kvcache(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lengths: torch.Tensor,
num_seqs: int,
max_num_blocks_per_seq: int,
block_size: int,
) -> torch.Tensor:
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
block_id = 0
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
num_tokens_processed = 0
for i, seq_len in enumerate(context_lengths.tolist()):
right_bound = (seq_len + block_size - 1) // block_size # open bound
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
# Manually fill kv caches by copying from k and v
for i in range(right_bound):
if i == right_bound - 1:
allocated_locs = seq_len % block_size or block_size
else:
allocated_locs = block_size
k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
k_cache[block_id, :, :, :allocated_locs] = k_block
v_cache[block_id, :, :, :allocated_locs] = v_block
num_tokens_processed += allocated_locs
block_id += 1
return block_tables
def mock_alloc_block_table_and_kvcache_v2(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lengths: torch.Tensor,
num_seqs: int,
max_num_blocks_per_seq: int,
block_size: int,
) -> torch.Tensor:
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
block_id = 0
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
num_tokens_processed = 0
for i, seq_len in enumerate(context_lengths.tolist()):
right_bound = (seq_len + block_size - 1) // block_size # open bound
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
# Manually fill kv caches by copying from k and v
for i in range(right_bound):
if i == right_bound - 1:
allocated_locs = seq_len % block_size or block_size
else:
allocated_locs = block_size
k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
k_cache[block_id, :, :allocated_locs, :] = k_block
v_cache[block_id, :, :allocated_locs, :] = v_block
num_tokens_processed += allocated_locs
block_id += 1
return block_tables
def mock_alloc_block_table_and_kvcache_v3(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lengths: torch.Tensor,
num_seqs: int,
max_num_blocks_per_seq: int,
block_size: int,
) -> torch.Tensor:
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
block_id = 0
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
num_tokens_processed = 0
_, num_kv_heads, head_dim = k.shape
x = 16 // torch.tensor([], dtype=k.dtype).element_size()
for i, seq_len in enumerate(context_lengths.tolist()):
right_bound = (seq_len + block_size - 1) // block_size # open bound
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
# Manually fill kv caches by copying from k and v
for i in range(right_bound):
if i == right_bound - 1:
allocated_locs = seq_len % block_size or block_size
else:
allocated_locs = block_size
# [block_size, num_kv_heads, head_dim/x, x]->[num_kv_heads, head_dim/x, block_size,x]
k_block = (
k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :]
.reshape(allocated_locs, num_kv_heads, head_dim // x, x)
.permute(1, 2, 0, 3)
)
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
k_cache[block_id, :, :, :allocated_locs, :] = k_block
v_cache[block_id, :, :allocated_locs, :] = v_block
num_tokens_processed += allocated_locs
block_id += 1
return block_tables
def mock_alloc_block_table_and_kvcache_vllm(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lengths: torch.Tensor,
num_seqs: int,
max_num_blocks_per_seq: int,
block_size: int,
) -> torch.Tensor:
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
block_id = 0
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
num_tokens_processed = 0
_, num_kv_heads, head_dim = k.shape
x = 16 // torch.tensor([], dtype=k.dtype).element_size()
for i, seq_len in enumerate(context_lengths.tolist()):
right_bound = (seq_len + block_size - 1) // block_size # open bound
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
# Manually fill kv caches by copying from k and v
for i in range(right_bound):
if i == right_bound - 1:
allocated_locs = seq_len % block_size or block_size
else:
allocated_locs = block_size
# [block_size, num_kv_heads, head_dim/x, x]->[num_kv_heads, head_dim/x, block_size,x]
k_block = (
k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :]
.reshape(allocated_locs, num_kv_heads, head_dim // x, x)
.permute(1, 2, 0, 3)
)
# [block_size, num_kv_heads, head_dim]->[num_kv_heads, head_dim, block_size]
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
k_cache[block_id, :, :, :allocated_locs, :] = k_block
v_cache[block_id, :, :, :allocated_locs] = v_block
num_tokens_processed += allocated_locs
block_id += 1
return block_tables
def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
# Allocate 1 token on the block table for each seqs in block tables.
# It won't change provided context_lengths.
# Consider max_block_id as the last physical block allocated
# NOTE It assumes all the blocks preceding this block have been allocated
max_block_id = torch.max(block_tables).item()
# the indices on each block table representing the cache block to be allocated one more token
alloc_local_block_indices = context_lengths // block_size
# offsets of the token to be allocated on the target block (for each seq)
alloc_block_offsets = context_lengths % block_size
require_new_block = alloc_block_offsets == 0
new_block_ids = torch.arange(
max_block_id + 1,
max_block_id + 1 + require_new_block.sum(),
dtype=block_tables.dtype,
device=block_tables.device,
)
if new_block_ids.numel():
new_block_alloc_local_indices = alloc_local_block_indices[require_new_block]
block_tables[require_new_block, new_block_alloc_local_indices] = new_block_ids
def generate_caches_and_block_tables(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
# Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache(
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
)
return k_cache, v_cache, block_tables
def generate_caches_and_block_tables_v2(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
# Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache_v2(
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
)
return k_cache, v_cache, block_tables
def generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
# Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
x = 16 // torch.tensor([], dtype=dtype).element_size()
k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache_v3(
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
)
return k_cache, v_cache, block_tables
def generate_caches_and_block_tables_vllm(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
# Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
x = 16 // torch.tensor([], dtype=dtype).element_size()
k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache_vllm(
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
)
return k_cache, v_cache, block_tables
def convert_kv_unpad_to_padded(
k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
) -> torch.Tensor:
# Rebuild (batched) k/v with padding to be used by torch attention
# input k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
# returns k/v padded [bsz, num_kv_heads, max_seq_len, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k_unpad.dtype, device=k_unpad.device)
prev_len_sum = 0
for i, seq_len in enumerate(kv_seq_lengths.tolist()):
# left-side padding
k_torch[i, -seq_len:, :, :] = k_unpad[prev_len_sum : prev_len_sum + seq_len]
prev_len_sum += seq_len
k_torch = k_torch.transpose(1, 2)
return k_torch

View File

@@ -0,0 +1,179 @@
import pytest
import torch
from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,
)
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 32
def _fill_with_neg_inf(t):
return t.float().fill_(float("-inf")).type_as(t)
# alibi mask calculation adapted from https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py
def generate_alibi_mask(slopes, num_heads, max_seq_len, device):
token_position = torch.arange(max_seq_len, device=device) - max_seq_len + 1
token_position = token_position.unsqueeze(0).unsqueeze(0).expand(num_heads, -1, -1)
diag = torch.diag(token_position[0])
token_position = token_position - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
alibi = slopes.unsqueeze(1).unsqueeze(1) * token_position
alibi = alibi.view(num_heads, 1, max_seq_len)
alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_seq_len, max_seq_len], device=device)), 1)
alibi_mask = alibi_mask.unsqueeze(0) + alibi
return alibi_mask
def torch_attn_unpad(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
context_lengths: torch.Tensor,
num_heads: int,
num_kv_heads: int,
slopes: torch.Tensor = None,
):
# Process sequence one by one and concatenate them together.
# q,k,v [num_tokens(sum(context_lengths)), num_heads, head_dim]
assert context_lengths.dim() == 1, "context_lengths should be a 1D tensor"
_, num_heads, head_dim = q.shape
out_torch = []
start_idx = 0
for seq_i in range(len(context_lengths)):
end_idx = start_idx + context_lengths[seq_i].item()
seq_len = end_idx - start_idx
mask = torch.tril(torch.ones(1, 1, seq_len, seq_len), diagonal=0).to(device=q.device)
mask[mask == 0.0] = float("-inf")
if slopes is not None:
alibi_mask = generate_alibi_mask(slopes, num_heads, seq_len, q.device)
mask = mask + alibi_mask
torch_attn_ref_out = torch_attn_ref(
q[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
k[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
v[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
mask,
1, # set bsz as 1 as we're processing sequence one by one
seq_len,
seq_len,
num_heads,
num_kv_heads,
head_dim,
)
out_torch.append(torch_attn_ref_out.squeeze(0))
start_idx = end_idx
return torch.cat(out_torch, dim=0)
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_attn_heads", [16])
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
def test_context_attention(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_attn_heads: int,
kv_group_num: int,
same_context_len: bool,
use_alibi_slopes: bool,
use_new_kcache_layout: bool,
):
if use_new_kcache_layout and use_alibi_slopes:
# TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
# the code (alibi kernel) will be refactored later to avoid code duplication, when
# the whole triton flow with new k cache layout has been supported and tested.
# And tests for the alibi kernel using new kcache layout will be added then.
return
torch.manual_seed(123)
# It's necessary to clear cache here.
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
num_kv_heads = num_attn_heads // kv_group_num
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
max_seq_len = max_num_blocks_per_seq * block_size
dtype = torch.float16
device = get_current_device()
alibi_slopes = None
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(num_attn_heads, device)
if same_context_len:
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
else:
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
num_tokens = torch.sum(context_lengths).item()
qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM)
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
q_unpad = q_unpad.contiguous()
if use_new_kcache_layout:
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
else:
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
k_cache_triton = torch.zeros_like(k_cache_ref)
v_cache_triton = torch.zeros_like(v_cache_ref)
_, num_heads, head_dim = q_unpad.shape
out_triton = context_attention_unpadded(
q_unpad,
k_unpad,
v_unpad,
k_cache_triton,
v_cache_triton,
context_lengths,
block_tables,
block_size,
alibi_slopes=alibi_slopes,
use_new_kcache_layout=use_new_kcache_layout,
)
out_triton = out_triton.view(-1, num_heads, head_dim)
out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads, alibi_slopes)
assert out_torch.shape == out_triton.shape
assert torch.allclose(out_torch, out_triton, atol=1e-3)
assert torch.equal(k_cache_ref, k_cache_triton)
assert torch.equal(v_cache_ref, v_cache_triton)
if __name__ == "__main__":
test_context_attention(4, 32, 8, 16, 1, True, True, True)

View File

@@ -0,0 +1,197 @@
import numpy as np
import pytest
import torch
from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,
)
from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 128
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
def prepare_data(
bsz: int,
num_attn_heads: int,
num_kv_heads: int,
head_dim: int,
same_context_len: bool,
q_len: int,
max_kv_seq_len: int,
dtype=torch.float16,
device="cuda",
):
# Use the provided maximum sequence length for each sequence when testing with teh same context length,
# otherwise generate random context lengths.
# returns
# q [bsz, num_attn_heads, q_len, head_dim]
# k_unpad/v_unpad [num_tokens, num_kv_heads, head_dim]
kv_lengths = (
torch.tensor([max_kv_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
if same_context_len
else torch.randint(low=1, high=max_kv_seq_len, size=(bsz,), dtype=torch.int32, device=device)
)
num_tokens = torch.sum(kv_lengths).item()
q_size = (bsz, q_len, num_attn_heads, head_dim)
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
return q, k_unpad, v_unpad, kv_lengths
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_attn_heads", [16])
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("q_len", [1, 5])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
def test_flash_decoding(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_attn_heads: int,
kv_group_num: int,
same_context_len: bool,
q_len: int,
use_alibi_slopes: bool,
use_new_kcache_layout: bool,
):
if use_new_kcache_layout and use_alibi_slopes:
# TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
# the code (alibi kernel) will be refactored later to avoid code duplication, when
# the whole triton flow with new k cache layout has been supported and tested.
# And tests for the alibi kernel using new kcache layout will be added then.
pytest.skip("Alibi kernel does not support new kcache layout yet.")
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
num_kv_heads = num_attn_heads // kv_group_num
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
max_seq_len = block_size * max_num_blocks_per_seq
dtype = torch.float16
device = get_current_device()
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(num_attn_heads, device)
# Currently, alibi flash decoding does not support q_len>1.
q_len = 1
else:
alibi_slopes = None
q, k_unpad, v_unpad, kv_lengths = prepare_data(
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, q_len, max_seq_len, dtype, device
)
# The maximum sequence length in the batch (if context lengths randomly generated)
max_kv_len_in_b = kv_lengths.max().item()
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
if use_alibi_slopes:
alibi_mask = generate_alibi_mask(alibi_slopes, num_attn_heads, max_kv_len_in_b, q.device)
attention_mask = attention_mask + alibi_mask
if q_len == 1:
if len(attention_mask.size()) == 4:
attention_mask = attention_mask[:, :, -1:, :]
else:
attention_mask = attention_mask[:, -1:, :]
out_torch = torch_attn_ref(
q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
)
if use_new_kcache_layout:
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
else:
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
# The maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz * q_len, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
mid_output = torch.empty(
size=(bsz * q_len, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
mid_output_lse = torch.empty(
size=(bsz * q_len, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device
)
sm_scale = 1.0 / (HEAD_DIM**0.5)
# Here we use different methods to hide the q_len dimension,
# refer to attention forward function in modeling.
if q_len > 1:
q = q.transpose(1, 2).contiguous() # [bsz, q_len, num_heads, head_dim]
q = q.view(-1, q.size(-2), q.size(-1)) # [bsz * q_len, num_heads, head_dim]
else:
q = q.squeeze(2)
assert q.shape == (bsz * q_len, num_attn_heads, HEAD_DIM)
out_triton = flash_decoding_attention(
q,
k_cache,
v_cache,
kv_lengths,
block_tables,
block_size,
max_kv_len_in_b,
output,
mid_output,
mid_output_lse,
alibi_slopes=alibi_slopes,
sm_scale=sm_scale,
kv_group_num=kv_group_num,
q_len=q_len,
use_new_kcache_layout=use_new_kcache_layout,
) # [bsz * q_len, num_heads, head_dim]
assert out_torch.shape == out_triton.shape
rtol = 1e-4
# After the shape becomes larger, some data elements are too small, leading to excessively large relative errors.
if bsz == 32 and use_alibi_slopes:
rtol = 100
numpy_allclose(out_torch, out_triton, atol=1e-3, rtol=rtol)
if __name__ == "__main__":
test_flash_decoding(16, 32, 32, 16, 1, True, 1, use_alibi_slopes=False, use_new_kcache_layout=True)

View File

@@ -0,0 +1,50 @@
from copy import deepcopy
import pytest
import torch
from packaging import version
from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
from colossalai.kernel.triton.no_pad_rotary_embedding import rotary_embedding
from colossalai.kernel.triton.rotary_cache_copy import get_xine_cache
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
def test_fused_rotary_emb():
num_tokens = 20
num_kv_heads = 32
head_dim = 64
dtype = torch.float32
q_shape = (num_tokens, num_kv_heads, head_dim)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
q_copy = deepcopy(q)
k_shape = (num_tokens, num_kv_heads, head_dim)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
k_copy = deepcopy(k)
cos_shape = (1024, head_dim)
lengths = torch.tensor([3, 4, 6, 7], device="cuda")
cos_cache = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin_cache = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
cos, sin = get_xine_cache(lengths, cos_cache[:, : head_dim // 2], sin_cache[:, : head_dim // 2])
rotary_embedding(q, k, cos, sin)
fused_rotary_embedding(q_copy, k_copy, cos_cache, sin_cache, lengths)
torch.allclose(q, q_copy)
torch.allclose(k, k_copy)
if __name__ == "__main__":
test_fused_rotary_emb()

View File

@@ -0,0 +1,168 @@
import pytest
import torch
from packaging import version
from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
mock_alloc_single_token,
)
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
HEAD_DIM = 32
def prepare_data(
bsz,
num_kv_heads,
head_dim,
block_size,
max_num_blocks_per_seq,
same_context_len,
max_seq_len,
n=1,
device="cuda",
dtype=torch.float16,
use_new_kcache_layout=False,
):
assert max_seq_len > n, "max_seq_len must be greater than n"
past_kv_seq_lengths = (
torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
if same_context_len
else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
)
num_tokens = torch.sum(past_kv_seq_lengths).item()
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
if use_new_kcache_layout:
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
)
else:
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
)
block_tables = block_tables.to(device=device)
new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
new_v = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
# mock allocating blocks for the new k/v and update block tables
for _ in range(n):
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
past_kv_seq_lengths += 1
return new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
@pytest.mark.parametrize("n_tokens", [1, 5])
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
def test_copy_kv_to_caches(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
n_tokens: int,
use_new_kcache_layout: bool,
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
max_seq_len = block_size * max_num_blocks_per_seq
dtype = torch.float16
device = get_current_device()
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
HEAD_DIM,
block_size,
max_num_blocks_per_seq,
same_context_len,
max_seq_len,
n_tokens,
device=device,
dtype=dtype,
use_new_kcache_layout=use_new_kcache_layout,
)
k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1))
v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1))
k_cache_copy = k_cache.detach().clone()
past_kv_seq_lengths = kv_seq_lengths - n_tokens
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_lengths // block_size]
offsets_in_block = past_kv_seq_lengths % block_size
# Copy k (or v) to k (or v) cache
copy_k_to_blocked_cache(
new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens, use_new_kcache_layout=use_new_kcache_layout
)
# Reshape target k from k cache to compare if matching with original tensor
# Mainly to handle cases of n_tokens > 1
k_target = []
for i in range(bsz):
block_table = block_tables[i]
curr_kv_len = past_kv_seq_lengths[i].item()
offset = offsets_in_block[i].item()
tokens_left = n_tokens
while tokens_left > 0:
tokens_to_fill = min(block_size - offset, tokens_left)
curr_block_id = block_table[curr_kv_len // block_size]
if use_new_kcache_layout:
k_target.append(k_cache[curr_block_id, :, :, offset : offset + tokens_to_fill, :])
else:
k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
curr_kv_len += tokens_to_fill
tokens_left -= tokens_to_fill
offset = 0
if use_new_kcache_layout:
k_target = torch.concat(k_target, dim=2).permute(2, 0, 1, 3).contiguous()
k_target = k_target.reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
else:
k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim]
assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
if n_tokens == 1:
# Copy k and v to k/v caches
k_cache = k_cache_copy
copy_kv_to_blocked_cache(
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables, use_new_kcache_layout=use_new_kcache_layout
)
if use_new_kcache_layout:
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
k_target = k_target.contiguous().reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
else:
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if __name__ == "__main__":
test_copy_kv_to_caches(4, 32, 8, 16, True, n_tokens=1)

View File

@@ -0,0 +1,55 @@
import pytest
import torch
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from colossalai.kernel.triton import rms_layernorm
from colossalai.testing.utils import parameterize
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@parameterize("M", [2, 4, 8, 16])
@parameterize("N", [64, 128])
def test_layer_norm(M, N):
dtype = torch.float16
eps = 1e-5
x_shape = (M, N)
w_shape = (x_shape[-1],)
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
residual = torch.rand(x_shape, dtype=dtype, device="cuda")
residual_copy = residual.clone()
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
x_copy = x.clone()
y_triton, _ = rms_layernorm(x, weight, eps=eps)
y_llama = rms_norm.forward(x).to(dtype)
assert y_triton.shape == y_llama.shape
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
y_triton, residual = rms_layernorm(x, weight, eps=eps, residual=residual)
x = x_copy + residual_copy
y_llama = rms_norm.forward(x).to(dtype)
assert y_triton.shape == y_llama.shape
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
if __name__ == "__main__":
test_layer_norm()

View File

@@ -0,0 +1,100 @@
import pytest
import torch
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from colossalai.kernel.triton import decoding_fused_rotary_embedding
from tests.test_infer.test_kernels.triton.kernel_utils import (
mock_alloc_block_table_and_kvcache_v2,
mock_alloc_block_table_and_kvcache_v3,
)
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
def torch_rotary_emb(x, cos, sin):
seq_len, h, dim = x.shape
x0 = x[:, :, 0 : dim // 2]
x1 = x[:, :, dim // 2 : dim]
cos = cos.view((seq_len, 1, dim // 2))
sin = sin.view((seq_len, 1, dim // 2))
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
return torch.cat((o0, o1), dim=-1)
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("SEQ_LEN", [64])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
cos_2 = cos[:, :32]
sin_2 = sin[:, :32]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
assert torch.allclose(embd_x0, embd_stimulated_x)
# create data
block_size = 32
max_num_blocks_per_seq = 4
q_shape = (TOTAL_TOKENS, H, D)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
k_shape = (TOTAL_TOKENS, H, D)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
v = torch.randn_like(k)
new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
new_q = torch.randn_like(new_k)
new_v = torch.randn_like(new_k)
cos_shape = (TOTAL_TOKENS, D // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
v_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
if use_new_kcache_layout:
x = 16 // torch.tensor([], dtype=dtype).element_size()
kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, D // x, block_size, x)
k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v3(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
)
else:
k_cache = torch.zeros_like(v_cache)
block_tables = mock_alloc_block_table_and_kvcache_v2(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
)
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
decoding_fused_rotary_embedding(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout
)
assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
if __name__ == "__main__":
test_rotary_emb(4, 64, 32, 64, torch.float32, use_new_kcache_layout=True)

View File

@@ -0,0 +1,66 @@
import pytest
import torch
from packaging import version
from colossalai.kernel.triton import get_xine_cache
try:
import triton # noqa
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
"""
Get cos and sin for the cache, and return nopad format.
Args:
lengths: shape(num_seqs,), stores lenghth of each sequence.
cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constrcuted in model.
sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constrcuted in model.
is_prompts: bool, mark if in prefill mode.
dtype: The data type of this inference process.
"""
if is_prompts:
index_arrays = [torch.arange(length) for length in lengths]
else:
index_arrays = [(length - 1).view(-1) for length in lengths]
indices = torch.cat(index_arrays, dim=-1)
cos_output = cos_cache[indices].to(dtype=dtype)
sin_output = sin_cache[indices].to(dtype=dtype)
return (cos_output, sin_output)
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
@pytest.mark.parametrize("HEAD_DIM", [64])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN
cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda")
# prefill
cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
assert torch.allclose(cos, cos_ref)
assert torch.allclose(sin, sin_ref)
# decoding
ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False)
assert torch.allclose(cos, ncos_ref)
assert torch.allclose(sin, nsin_ref)
if __name__ == "__main__":
test_get_xine_cache(4, 64, 256, torch.float32)