Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 01:06:00 +00:00)
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel)
* revise shape of kvcache (flash decoding kernel)
* revise shape of kvcache (kvcache copy) and attn func
* init of kvcache in kvcache manager
* revise llama modeling
* revise block size retrieval
* use torch for rms_norm benchmarking
* revise block size retrieval
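The core change across these files is the blocked KV cache layout: the block_size and head_size axes swap places, going from (num_blocks, num_kv_heads, head_size, block_size) to (num_blocks, num_kv_heads, block_size, head_size), so each cached token's head vector becomes contiguous in memory. A minimal sketch of the two layouts and of writing one decoding-step token (plain PyTorch; the tensor names and the identity block table are illustrative, not the ColossalAI API):

import torch

num_blocks, num_kv_heads, head_dim, block_size = 8, 4, 128, 32

# Old layout: the token slot is the innermost axis of each block
k_cache_old = torch.zeros(num_blocks, num_kv_heads, head_dim, block_size)
# New layout: head_dim is the innermost axis, one contiguous vector per cached token
k_cache_new = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim)

# Key of a single new token for one sequence during decoding
new_k = torch.randn(num_kv_heads, head_dim)
past_kv_len = 45                                # tokens already cached for this sequence
block_idx = past_kv_len // block_size           # index into the sequence's block table
offset = past_kv_len % block_size               # slot inside that block
block_id = block_idx                            # assume an identity block table for this sketch

k_cache_old[block_id, :, :, offset] = new_k     # old indexing
k_cache_new[block_id, :, offset, :] = new_k     # new indexing, as used by the revised kernels below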
@@ -93,7 +93,7 @@ def check_cache_manager(test_config):
    assert len(cache_manager._cache_blocks) == num_blocks
    key_caches = cache_manager._kv_caches[0]  # key caches for all the blocks in all the layers
    assert len(key_caches) == num_layers
-    expected_kv_shape = (num_blocks, num_attention_heads, head_size, block_size)
+    expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size)
    assert key_caches[0].shape == expected_kv_shape
    k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
    expected_kv_block_shape = expected_kv_shape[1:]
@@ -1,20 +1,17 @@
-import pytest
import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-import colossalai
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
-from colossalai.testing import rerun_if_address_is_in_use, spawn


def test_copy_to_cache():
    key = torch.ones((2, 11, 3, 3))
    key[0, 9, :, :] = 0
    key[1, -2:, :, :] = 0
-    cache = torch.zeros(8, 3, 3, 8)
+    cache = torch.zeros(8, 3, 8, 3)
    block_tables = torch.tensor([[0, 1], [2, 3]])
    lengths = torch.tensor([9, 8])
    cache = copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="prefill")
@@ -28,7 +25,7 @@ def test_copy_to_cache():


def test_convert_kvcache():
-    cache = torch.ones(8, 3, 3, 8)
+    cache = torch.ones(8, 3, 8, 3)
    key = torch.ones(2, 1, 3, 3) + 1
    lengths = torch.tensor([10, 9])
    block_tables = torch.tensor([[0, 1], [2, 3]])
@@ -43,8 +40,8 @@ def test_context_attention():
    """
    attn = PagedAttention()
    q = k = v = torch.randn(8, 4, 4)
-    k_cache = torch.empty(8, 4, 4, 8)
-    v_cache = torch.empty(8, 4, 4, 8)
+    k_cache = torch.empty(8, 4, 8, 4)
+    v_cache = torch.empty(8, 4, 8, 4)
    context_lengths = torch.tensor(
        [
            8,
@@ -136,23 +133,8 @@ def test_decoding_attention():
    assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)


-def check_attention_layer():
+if __name__ == "__main__":
    test_copy_to_cache()
    test_convert_kvcache()
    test_context_attention()
    test_decoding_attention()
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
-    check_attention_layer()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_attention_layer():
-    spawn(run_dist, 1)
-
-
-if __name__ == "__main__":
-    test_attention_layer()
@@ -106,6 +106,40 @@ def mock_alloc_block_table_and_kvcache(
    return block_tables


+def mock_alloc_block_table_and_kvcache_v2(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    context_lengths: torch.Tensor,
+    num_seqs: int,
+    max_num_blocks_per_seq: int,
+    block_size: int,
+) -> torch.Tensor:
+    """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
+    block_id = 0
+    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
+    num_tokens_processed = 0
+    for i, seq_len in enumerate(context_lengths.tolist()):
+        right_bound = (seq_len + block_size - 1) // block_size  # open bound
+        block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
+        # Manually fill kv caches by copying from k and v
+        for i in range(right_bound):
+            if i == right_bound - 1:
+                allocated_locs = seq_len % block_size or block_size
+            else:
+                allocated_locs = block_size
+            k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
+            v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
+            k_cache[block_id, :, :allocated_locs, :] = k_block
+            v_cache[block_id, :, :allocated_locs, :] = v_block
+
+            num_tokens_processed += allocated_locs
+            block_id += 1
+
+    return block_tables
+
+
def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
    # Allocate 1 token on the block table for each seqs in block tables.
    # It won't change provided context_lengths.
@@ -146,6 +180,22 @@ def generate_caches_and_block_tables(
    return k_cache, v_cache, block_tables


+def generate_caches_and_block_tables_v2(
+    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
+) -> Tuple[torch.Tensor, ...]:
+    # Mock generation of k/v blocked caches and block tables from provided kv unpad and seq lengths
+    # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
+    _, num_kv_heads, head_dim = k_unpad.shape
+    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
+    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
+    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
+    # Mock allocation on block tables as well as blocked kv caches
+    block_tables = mock_alloc_block_table_and_kvcache_v2(
+        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
+    )
+    return k_cache, v_cache, block_tables
+
+
def convert_kv_unpad_to_padded(
    k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
) -> torch.Tensor:
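The kernel tests in the hunks below switch to this v2 helper. A minimal usage sketch with assumed toy sizes (the helper and its import path come from the diff above; the dimensions, lengths, and the CPU device are illustrative only):

import torch

from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2

# Toy sizes for illustration
bsz, max_num_blocks_per_seq, block_size = 2, 4, 16
num_kv_heads, head_dim = 4, 64
kv_lengths = torch.tensor([19, 7], dtype=torch.int32)  # context length per sequence
num_total_tokens = int(kv_lengths.sum())

# Unpadded K/V: one row per token across the whole batch
k_unpad = torch.randn(num_total_tokens, num_kv_heads, head_dim, dtype=torch.float16)
v_unpad = torch.randn(num_total_tokens, num_kv_heads, head_dim, dtype=torch.float16)

k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cpu"
)

# Blocked caches use the revised layout; -1 marks unallocated entries in the block table
assert k_cache.shape == (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
assert block_tables.shape == (bsz, max_num_blocks_per_seq)
assert block_tables[0, :2].tolist() == [0, 1]  # 19 tokens with block_size 16 span two blocks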
@@ -6,7 +6,7 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
-from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables, torch_attn_ref
+from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref

try:
    import triton  # noqa
@@ -93,7 +93,7 @@ def test_context_attention(
    q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
    q_unpad = q_unpad.contiguous()

-    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables(
+    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
        k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
    )
    block_tables = block_tables.to(device=device)
@@ -148,7 +148,6 @@ def bench_kernel(

    num_kv_heads = num_attn_heads // kv_group_num
    assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
-    block_size * max_num_blocks_per_seq
    dtype = torch.float16
    device = get_current_device()

@@ -162,7 +161,7 @@ def bench_kernel(
    qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
    q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
    q_unpad = q_unpad.contiguous()
-    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables(
+    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
        k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
    )
    block_tables = block_tables.to(device=device)

@@ -6,7 +6,7 @@ from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer_ops.triton.kernel_utils import (
    convert_kv_unpad_to_padded,
-    generate_caches_and_block_tables,
+    generate_caches_and_block_tables_v2,
    prepare_padding_mask,
    torch_attn_ref,
)
@@ -38,6 +38,9 @@ def prepare_data(
):
    # Use the provided maximum sequence length for each sequence when testing with the same context length,
    # otherwise generate random context lengths.
+    # returns
+    # q [bsz, num_attn_heads, q_len, head_dim]
+    # k_unpad/v_unpad [num_tokens, num_kv_heads, head_dim]
    kv_lengths = (
        torch.tensor([max_kv_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
        if same_context_len
@@ -83,7 +86,7 @@ def test_flash_decoding(
    q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
        bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
    )
-    k_cache, v_cache, block_tables = generate_caches_and_block_tables(
+    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
        k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
    )
    block_tables = block_tables.to(device=device)
@@ -180,7 +183,7 @@ def bench_kernel(
        )
        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
    if provider == "triton":
-        k_cache, v_cache, block_tables = generate_caches_and_block_tables(
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
        )
        block_tables = block_tables.to(device=device)

@@ -5,7 +5,7 @@ from packaging import version
from colossalai.inference.modeling.layers.attention import copy_to_cache
from colossalai.kernel.triton import copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
-from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, mock_alloc_single_token
+from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token

try:
    import triton  # noqa
@@ -17,6 +17,8 @@ except ImportError:

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")

+HEAD_DIM = 128
+

def prepare_data(
    bsz,
@@ -29,31 +31,27 @@ def prepare_data(
    device,
    dtype=torch.float16,
):
-    if same_context_len:
-        # past_kv_seq_lengths in this test records the previous kv seq len
-        # (not incorporating the current input whose seq len is 1)
-        past_kv_seq_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
-    else:
-        past_kv_seq_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
+    # past_kv_seq_lengths in this test records the previous kv seq len
+    # (not incorporating the current input whose seq len is 1)
+    past_kv_seq_lengths = (
+        torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
+        if same_context_len
+        else torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
+    )
    num_tokens = torch.sum(past_kv_seq_lengths).item()

    kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
-    kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-    k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)
+    kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+    k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)

-    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
-    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
-    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
-    # Mock allocation on block tables as well as blocked kv caches
-    block_tables = mock_alloc_block_table_and_kvcache(
-        k, v, k_cache, v_cache, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size
+    k_cache, _, block_tables = generate_caches_and_block_tables_v2(
+        k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
    )
    block_tables = block_tables.to(device=device)

    new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
    # mock allocating blocks for the new k/v and update block tables
    mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)

    # kv seq len = past kv seq len + seq len (1 during decoding stage)
    kv_seq_lengths = past_kv_seq_lengths + 1
@@ -78,7 +76,6 @@ def test_copy_kv_to_caches(
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

-    head_dim = 128
    max_seq_len = block_size * max_num_blocks_per_seq
    dtype = torch.float16
    device = get_current_device()
@@ -86,7 +83,7 @@ def test_copy_kv_to_caches(
    new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
        bsz,
        num_kv_heads,
-        head_dim,
+        HEAD_DIM,
        block_size,
        max_num_blocks_per_seq,
        same_context_len,
@@ -94,20 +91,28 @@ def test_copy_kv_to_caches(
        device=device,
        dtype=dtype,
    )
+    # k_cache_torch = k_cache.clone().detach()
+    # copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
    copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)

-    for seq_i in range(bsz):
-        ki = new_k[seq_i]
-        ki = ki.squeeze()
-        past_kv_seq_len = kv_seq_lengths[seq_i] - 1
-        target_block_id = block_tables[seq_i, past_kv_seq_len // block_size]
-        offsets_in_block = past_kv_seq_len % block_size
-        target = k_cache[target_block_id, :, :, offsets_in_block]
-        orig = new_k[seq_i].squeeze(dim=0)
-        assert torch.equal(orig, target)
+    past_kv_seq_len = kv_seq_lengths - 1
+    target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
+    offsets_in_block = past_kv_seq_len % block_size
+    target = k_cache[target_block_ids, :, offsets_in_block, :]
+    source = new_k.squeeze()
+
+    assert target.shape == source.shape
+    assert torch.equal(target, source)
+    # target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
+    # assert target_torch.shape == source.shape
+    # assert torch.equal(target_torch, source)


BATCH = 16
BLOCK_SIZE = 32
SAME_LEN = True
+WARM_UPS = 10
+REPS = 100
configs = [
    triton.testing.Benchmark(
        x_names=["KV_SEQ_LEN"],
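The vectorized check above boils down to simple block arithmetic. A worked example with assumed numbers (not taken from the test parameters): for a sequence whose past KV length is 45 and a block size of 32, the token written at this decoding step lands in the second entry of its block table, at slot 13 of that block:

block_size = 32
past_kv_seq_len = 45                              # kv_seq_length - 1 for this sequence

block_idx = past_kv_seq_len // block_size         # 45 // 32 = 1 -> second block-table entry
offset_in_block = past_kv_seq_len % block_size    # 45 % 32 = 13 -> slot inside that block

# With the revised layout the key would be read back as
# k_cache[block_tables[seq_i, block_idx], :, offset_in_block, :]
print(block_idx, offset_in_block)  # 1 13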
@@ -133,10 +138,6 @@ def benchmark_kvcache_copy(
    num_kv_heads: int,
    same_context_len: bool,
):
-    warmup = 10
-    rep = 100
-
-    head_dim = 128
    dtype = torch.float16
    device = get_current_device()

@@ -145,7 +146,7 @@ def benchmark_kvcache_copy(
    new_k, k_cache, context_lengths, block_tables = prepare_data(
        bsz,
        num_kv_heads,
-        head_dim,
+        HEAD_DIM,
        block_size,
        max_seq_len // block_size,
        same_context_len,
@@ -154,15 +155,14 @@ def benchmark_kvcache_copy(
        dtype=dtype,
    )

+    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch_copy_func":
        fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
-    elif provider == "triton_copy_func":
+    if provider == "triton_copy_func":
        fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
-    else:
-        raise ValueError("Undefined provider.")

-    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
-    return ms
+    ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
+    return ms, min_ms, max_ms


if __name__ == "__main__":
@@ -3,7 +3,6 @@ import torch
import triton
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRMSNorm
-from vllm.model_executor.layers.layernorm import RMSNorm

from colossalai.kernel.triton import rms_layernorm
from colossalai.testing.utils import parameterize
@@ -36,7 +35,8 @@ def test_layer_norm(M, N):
    y_triton = rms_layernorm(x, weight, eps=eps)
    y_llama = rms_norm.forward(x).to(dtype)

-    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
+    assert y_triton.shape == y_llama.shape
+    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)


# Triton benchmark plot attributions
@@ -45,8 +45,8 @@
        x_names=["SEQUENCE_TOTAL"],
        x_vals=[i for i in range(128, 1025, 128)],
        line_arg="provider",
-        line_vals=["vllm_rms_layernorm", "triton_rms_layernorm"],
-        line_names=["vllm_rms_layernorm", "triton_rms_layernorm"],
+        line_vals=["torch_rms_layernorm", "triton_rms_layernorm"],
+        line_names=["torch_rms_layernorm", "triton_rms_layernorm"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name=f"RMSNorm benchmarking results",
@@ -69,10 +69,10 @@ def benchmark_rms_layernorm(
    x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
    w_shape = (x_shape[-1],)
    weight = torch.ones(w_shape, dtype=dtype, device="cuda")
-    vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
+    torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-    if provider == "vllm_rms_layernorm":
-        fn = lambda: vllm_norm(x)
+    if provider == "torch_rms_layernorm":
+        fn = lambda: torch_norm(x)
    elif provider == "triton_rms_layernorm":
        fn = lambda: rms_layernorm(x, weight, eps=eps)
    else: