diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py
index af4395f4b..e1bd935e9 100644
--- a/colossalai/inference/modeling/layers/attention.py
+++ b/colossalai/inference/modeling/layers/attention.py
@@ -31,7 +31,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
                 1, 2, 0
             )
     elif type == "decoding":
-        assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding."
+        assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
         source = source.squeeze(1)
         slot_idx = (lengths + block_size - 1) % block_size
         for i in range(bsz):
@@ -314,4 +314,4 @@ class PagedAttention:
     ):
         return self.pad_decoding_forward(
             q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1), k_cache, v_cache, lengths, block_tables
-        )
\ No newline at end of file
+        )
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index 4ac71ac64..021ccb9c1 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -12,12 +12,14 @@ if HAS_TRITON:
     from .flash_decoding import flash_decoding_fwd
     from .fused_layernorm import layer_norm
     from .gptq_triton import gptq_fused_linear_triton
+    from .kvcache_copy import copy_kv_to_blocked_cache
     from .no_pad_rotary_embedding import rotary_embedding
     from .softmax import softmax
 
     __all__ = [
         "context_attention_unpadded",
         "flash_decoding_fwd",
+        "copy_kv_to_blocked_cache",
         "softmax",
         "layer_norm",
         "gptq_fused_linear_triton",
diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py
new file mode 100644
index 000000000..b979e24cd
--- /dev/null
+++ b/colossalai/kernel/triton/kvcache_copy.py
@@ -0,0 +1,90 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# Triton 2.1.0
+@triton.jit
+def _copy_to_kvcache_seqlen1_kernel(
+    KV,  # K or V
+    KVCache,  # KCache or VCache
+    BLOCK_TABLES,
+    context_lengths,
+    stride_kt,
+    stride_kh,
+    stride_kd,
+    stride_cacheb,
+    stride_cacheh,
+    stride_cached,
+    stride_cachebs,
+    stride_bts,
+    stride_btb,
+    block_size,
+    HEAD_DIM: tl.constexpr,
+):
+    cur_seq_idx = tl.program_id(0)
+    cur_kv_head_idx = tl.program_id(1)
+
+    cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
+    last_bt_block_idx = cur_kv_seq_len // block_size
+    block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
+    block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
+    offsets_in_last_block = (cur_kv_seq_len % block_size) * stride_cachebs
+    offsets_dmodel = tl.arange(0, HEAD_DIM)
+    offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
+    kv = tl.load(KV + offsets_kv)
+    offsets_kvcache = (
+        block_id * stride_cacheb
+        + cur_kv_head_idx * stride_cacheh
+        + offsets_dmodel * stride_cached
+        + offsets_in_last_block
+    )
+    tl.store(KVCache + offsets_kvcache, kv)
+    return
+
+
+# Used with blocked kv cache.
+# Copy k or v to the blocked k/v cache during the decoding stage.
+def copy_kv_to_blocked_cache(
+    k: torch.Tensor,  # [bsz, 1, num_kv_heads, head_dim], k or v during decoding stage
+    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size], blocked k or v cache (for now, k cache and v cache share the same shape)
+    context_lengths: torch.Tensor,  # [bsz], past kv seq len (not incorporating the current kv of length 1)
+    block_tables: torch.Tensor,  # [bsz, max_blocks_per_sequence]
+):
+    assert k.dim() == 4, "Unsupported shape of k (supposed to be used for decoding stage)"
+    assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
+    assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
+    assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
+    bsz, _, num_kv_heads, head_dim = k.shape
+    assert context_lengths.shape[0] == block_tables.shape[0] == bsz, (
+        f"Got incompatible batch size (number of seqs):\n"
+        f" Context lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
+        f"batch size {bsz}"
+    )
+
+    # Modify this if the shape of the kv cache is changed.
+    block_size = k_cache.size(-1)
+    # [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
+    k = k.squeeze(dim=1)
+
+    num_warps = 8 if head_dim > 128 else 4
+
+    grid = (bsz, num_kv_heads)
+    _copy_to_kvcache_seqlen1_kernel[grid](
+        k,
+        k_cache,
+        block_tables,
+        context_lengths,
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        k_cache.stride(0),
+        k_cache.stride(1),
+        k_cache.stride(2),
+        k_cache.stride(3),
+        block_tables.stride(0),
+        block_tables.stride(1),
+        block_size,
+        HEAD_DIM=head_dim,
+        num_warps=num_warps,
+    )
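For readers new to the blocked layout: the kernel above writes the single new k (or v) token of each sequence into slot context_len % block_size of the physical block found at block_tables[seq, context_len // block_size]. Below is a minimal pure-PyTorch sketch of the same copy; the helper name is hypothetical and it is not part of this patch, it simply mirrors the per-sequence check performed by the new unit test further down.

# Reference-only sketch (hypothetical helper, not part of the patch): the semantics
# of copy_kv_to_blocked_cache expressed with plain PyTorch indexing.
import torch


def copy_kv_to_blocked_cache_ref(
    k: torch.Tensor,  # [bsz, 1, num_kv_heads, head_dim]
    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
    context_lengths: torch.Tensor,  # [bsz], past kv seq lens (new token not counted)
    block_tables: torch.Tensor,  # [bsz, max_blocks_per_sequence]
) -> None:
    block_size = k_cache.size(-1)
    k = k.squeeze(1)  # [bsz, num_kv_heads, head_dim]
    for i in range(k.size(0)):
        past_len = int(context_lengths[i])
        # physical block holding the new token, and the slot inside that block
        block_id = int(block_tables[i, past_len // block_size])
        slot = past_len % block_size
        k_cache[block_id, :, :, slot] = k[i]

The Triton kernel performs the same writes with one program per (sequence, kv head) pair, so the whole batch is handled without a Python-level loop.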
diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py
index 2f34c5463..3cd897931 100644
--- a/tests/test_infer_ops/triton/kernel_utils.py
+++ b/tests/test_infer_ops/triton/kernel_utils.py
@@ -100,3 +100,29 @@ def mock_alloc_block_table_and_kvcache(
             block_id += 1
 
     return block_tables
+
+
+def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int):
+    """Allocate 1 token on the block table for each seq in block tables.
+    It will not change the provided context_lengths.
+    """
+
+    # consider max_block_id as the last physical block allocated
+    # NOTE It assumes all the blocks preceding this block have been allocated
+    max_block_id = torch.max(block_tables).item()
+    # the indices on each block table representing the cache block to be allocated one more token
+    alloc_local_block_indices = context_lengths // block_size
+    # offsets of the token to be allocated on the target block (for each seq)
+    alloc_block_offsets = context_lengths % block_size
+
+    require_new_block = alloc_block_offsets == 0
+    new_block_ids = torch.arange(
+        max_block_id + 1,
+        max_block_id + 1 + require_new_block.sum(),
+        dtype=block_tables.dtype,
+        device=block_tables.device,
+    )
+
+    if new_block_ids.numel():
+        new_block_alloc_local_indices = alloc_local_block_indices[require_new_block]
+        block_tables[require_new_block, new_block_alloc_local_indices] = new_block_ids
diff --git a/tests/test_infer_ops/triton/test_kvcache_copy.py b/tests/test_infer_ops/triton/test_kvcache_copy.py
new file mode 100644
index 000000000..875c34fba
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_kvcache_copy.py
@@ -0,0 +1,168 @@
+import pytest
+import torch
+from packaging import version
+
+from colossalai.inference.modeling.layers.attention import copy_to_cache
+from colossalai.kernel.triton import copy_kv_to_blocked_cache
+from colossalai.utils import get_current_device
+from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, mock_alloc_single_token
+
+try:
+    import triton  # noqa
+
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
+
+
+def prepare_data(
+    bsz,
+    num_kv_heads,
+    head_dim,
+    block_size,
+    max_num_blocks_per_seq,
+    same_context_len,
+    max_seq_len,
+    device,
+    dtype=torch.float16,
+):
+    if same_context_len:
+        # context_lengths in this test records the previous kv seq len
+        # (not incorporating the current input whose seq len is 1)
+        context_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
+    else:
+        context_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
+    num_tokens = torch.sum(context_lengths).item()
+
+    kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
+    kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+    k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)
+
+    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
+    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
+    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
+    # Mock allocation on block tables as well as blocked kv caches
+    block_tables = mock_alloc_block_table_and_kvcache(
+        k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
+    )
+    block_tables = block_tables.to(device=device)
+
+    new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
+    # mock allocating a slot for the new k/v and update the block tables
+    mock_alloc_single_token(block_tables, context_lengths, block_size)
+
+    return new_k, k_cache, context_lengths, block_tables
+
+
+@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
+@pytest.mark.parametrize("bsz", [4, 7, 32])
+@pytest.mark.parametrize("block_size", [16, 32, 64])
+@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
+@pytest.mark.parametrize("num_kv_heads", [16])
+@pytest.mark.parametrize("same_context_len", [True, False])
+def test_copy_kv_to_caches(
+    bsz: int,
+    block_size: int,
+    max_num_blocks_per_seq: int,
+    num_kv_heads: int,
+    same_context_len: bool,
+):
+    torch.manual_seed(123)
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+    torch.cuda.reset_peak_memory_stats()
+
+    head_dim = 128
+    max_seq_len = block_size * max_num_blocks_per_seq
+    dtype = torch.float16
+    device = get_current_device()
+
+    new_k, k_cache, context_lengths, block_tables = prepare_data(
+        bsz,
+        num_kv_heads,
+        head_dim,
+        block_size,
+        max_num_blocks_per_seq,
+        same_context_len,
+        max_seq_len,
+        device=device,
+        dtype=dtype,
+    )
+
+    copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
+
+    for seq_i in range(bsz):
+        ki = new_k[seq_i]
+        ki = ki.squeeze()
+        context_len_i = context_lengths[seq_i]
+        target_block_id = block_tables[seq_i, context_len_i // block_size]
+        offsets_in_block = context_len_i % block_size
+        target = k_cache[target_block_id, :, :, offsets_in_block]
+        orig = new_k[seq_i].squeeze(dim=0)
+        assert torch.equal(orig, target)
+
+
+BATCH = 4
+configs = [
+    triton.testing.Benchmark(
+        x_names=["PAST_KVLEN"],
+        x_vals=[2**i - 1 for i in range(8, 13)],
+        line_arg="provider",
+        line_vals=["torch_copy_func", "triton_copy_func"],
+        line_names=["torch_copy_func", "triton_copy_func"],
+        styles=[("red", "-"), ("blue", "-")],
+        ylabel="ms",
+        plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
+        args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
+    )
+]
+
+
+@triton.testing.perf_report(configs)
+def benchmark_kvcache_copy(
+    provider: str,
+    bsz: int,
+    block_size: int,
+    max_seq_len: int,
+    PAST_KVLEN: int,  # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
+    num_kv_heads: int,
+    same_context_len: bool,
+):
+    warmup = 10
+    rep = 100
+
+    head_dim = 128
+    dtype = torch.float16
+    device = get_current_device()
+
+    assert PAST_KVLEN < max_seq_len, "Assigned maximum past kv length must be smaller than the maximum seq len"
+
+    new_k, k_cache, context_lengths, block_tables = prepare_data(
+        bsz,
+        num_kv_heads,
+        head_dim,
+        block_size,
+        max_seq_len // block_size,
+        same_context_len,
+        PAST_KVLEN,
+        device=device,
+        dtype=dtype,
+    )
+
+    if provider == "torch_copy_func":
+        fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
+    elif provider == "triton_copy_func":
+        fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
+    else:
+        raise ValueError("Undefined provider.")
+
+    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+    return ms
+
+
+if __name__ == "__main__":
+    test_copy_kv_to_caches(4, 32, 8, 16, False)
+    # benchmark_kvcache_copy.run(save_path=".")
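Finally, a minimal usage sketch for one decoding step. The shapes follow the comments in copy_kv_to_blocked_cache; the concrete sizes, lengths and block-table contents below are made up for illustration, and a CUDA device with Triton 2.1.0 installed is assumed.

import torch

from colossalai.kernel.triton import copy_kv_to_blocked_cache

bsz, num_kv_heads, head_dim, block_size, max_blocks_per_seq = 2, 16, 128, 16, 4
device, dtype = "cuda", torch.float16

# Blocked k (or v) cache: [num_blocks, num_kv_heads, head_dim, block_size]
k_cache = torch.zeros(bsz * max_blocks_per_seq, num_kv_heads, head_dim, block_size, dtype=dtype, device=device)
# Past kv length per sequence; the token currently being decoded is not counted yet
context_lengths = torch.tensor([5, 17], dtype=torch.int32, device=device)
# Block tables: [bsz, max_blocks_per_seq]; the block at index context_len // block_size
# must already be allocated for each sequence (here blocks 0 to 4 are "allocated")
block_tables = torch.tensor([[0, 1, -1, -1], [2, 3, 4, -1]], dtype=torch.int32, device=device)

# New k (or v) of the current step: [bsz, 1, num_kv_heads, head_dim]
new_k = torch.randn(bsz, 1, num_kv_heads, head_dim, dtype=dtype, device=device)
copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)

# Sequence i's new token now lives at
# k_cache[block_tables[i, context_lengths[i] // block_size], :, :, context_lengths[i] % block_size]

The benchmark in the new test file compares exactly this call against the existing copy_to_cache(..., type="decoding") PyTorch path.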