mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-01 03:45:27 +00:00
[Inference/opt] Fused KVCahce Memcopy (#5374)
* fused kv memcopy * add TODO in test_kvcache_copy.py
This commit is contained in:
parent
58740b5f68
commit
6fb4bcbb24
@ -301,8 +301,9 @@ class NopadLlamaAttention(LlamaAttention):
|
|||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
copy_kv_to_blocked_cache(
|
||||||
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
|
||||||
|
)
|
||||||
attn_output = flash_decoding_attention(
|
attn_output = flash_decoding_attention(
|
||||||
q=query_states,
|
q=query_states,
|
||||||
k_cache=k_cache,
|
k_cache=k_cache,
|
||||||
|
@ -356,8 +356,9 @@ class PadLlamaAttention(LlamaAttention):
|
|||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attn_output = pad_input(attn_output, indices, bsz, q_len)
|
attn_output = pad_input(attn_output, indices, bsz, q_len)
|
||||||
else:
|
else:
|
||||||
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
copy_kv_to_blocked_cache(
|
||||||
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
|
||||||
|
)
|
||||||
attn_output = flash_decoding_attention(
|
attn_output = flash_decoding_attention(
|
||||||
q=query_states,
|
q=query_states,
|
||||||
k_cache=k_cache,
|
k_cache=k_cache,
|
||||||
|
@ -6,17 +6,26 @@ import triton.language as tl
|
|||||||
# Triton 2.1.0
|
# Triton 2.1.0
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _copy_to_kvcache_seqlen1_kernel(
|
def _copy_to_kvcache_seqlen1_kernel(
|
||||||
KV, # K or V
|
K, # K
|
||||||
KVCache, # KCache or VCache
|
V, # V
|
||||||
|
KCache, # KCache
|
||||||
|
VCache, # VCache
|
||||||
BLOCK_TABLES,
|
BLOCK_TABLES,
|
||||||
context_lengths,
|
context_lengths,
|
||||||
stride_kt,
|
stride_kt,
|
||||||
stride_kh,
|
stride_kh,
|
||||||
stride_kd,
|
stride_kd,
|
||||||
stride_cacheb,
|
stride_vt,
|
||||||
stride_cacheh,
|
stride_vh,
|
||||||
stride_cachebs,
|
stride_vd,
|
||||||
stride_cached,
|
stride_cachekb,
|
||||||
|
stride_cachekh,
|
||||||
|
stride_cachekbs,
|
||||||
|
stride_cachekd,
|
||||||
|
stride_cachevb,
|
||||||
|
stride_cachevh,
|
||||||
|
stride_cachevbs,
|
||||||
|
stride_cachevd,
|
||||||
stride_bts,
|
stride_bts,
|
||||||
stride_btb,
|
stride_btb,
|
||||||
block_size,
|
block_size,
|
||||||
@ -32,20 +41,33 @@ def _copy_to_kvcache_seqlen1_kernel(
|
|||||||
offsets_in_last_block = past_kv_seq_len % block_size
|
offsets_in_last_block = past_kv_seq_len % block_size
|
||||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||||
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||||
kv = tl.load(KV + offsets_kv)
|
|
||||||
|
k = tl.load(K + offsets_kv)
|
||||||
|
v = tl.load(V + offsets_kv)
|
||||||
|
|
||||||
offsets_kvcache = (
|
offsets_kvcache = (
|
||||||
block_id * stride_cacheb
|
block_id * stride_cachekb
|
||||||
+ cur_kv_head_idx * stride_cacheh
|
+ cur_kv_head_idx * stride_cachekh
|
||||||
+ offsets_in_last_block * stride_cachebs
|
+ offsets_in_last_block * stride_cachekbs
|
||||||
+ offsets_dmodel * stride_cached
|
+ offsets_dmodel * stride_cachekd
|
||||||
)
|
)
|
||||||
tl.store(KVCache + offsets_kvcache, kv)
|
offsets_kvcache = (
|
||||||
|
block_id * stride_cachevb
|
||||||
|
+ cur_kv_head_idx * stride_cachevh
|
||||||
|
+ offsets_in_last_block * stride_cachevbs
|
||||||
|
+ offsets_dmodel * stride_cachevd
|
||||||
|
)
|
||||||
|
|
||||||
|
tl.store(KCache + offsets_kvcache, k)
|
||||||
|
tl.store(VCache + offsets_kvcache, v)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def copy_kv_to_blocked_cache(
|
def copy_kv_to_blocked_cache(
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
k_cache: torch.Tensor,
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
kv_lengths: torch.Tensor,
|
kv_lengths: torch.Tensor,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
):
|
):
|
||||||
@ -53,16 +75,23 @@ def copy_kv_to_blocked_cache(
|
|||||||
Copy keys or values to the blocked key/value cache during decoding stage.
|
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys during decoding with seq len 1.
|
||||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
|
v (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Values during decoding with seq len 1.
|
||||||
|
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key cache.
|
||||||
|
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.
|
||||||
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||||
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||||
"""
|
"""
|
||||||
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
|
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
|
||||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||||
|
|
||||||
k = k.squeeze(1) if k.dim() == 4 else k
|
k = k.squeeze(1) if k.dim() == 4 else k
|
||||||
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
|
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
|
||||||
|
|
||||||
|
assert v.size(-1) == v_cache.size(-1), "Incompatible head dim"
|
||||||
|
assert v.dtype == v_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||||
|
v = v.squeeze(1) if v.dim() == 4 else v
|
||||||
|
assert v.dim() == 3, f"Incompatible v dim {v.dim()}"
|
||||||
|
|
||||||
bsz, num_kv_heads, head_dim = k.shape
|
bsz, num_kv_heads, head_dim = k.shape
|
||||||
|
|
||||||
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||||
@ -75,20 +104,28 @@ def copy_kv_to_blocked_cache(
|
|||||||
block_size = k_cache.size(-2)
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
num_warps = 8 if head_dim > 128 else 4
|
num_warps = 8 if head_dim > 128 else 4
|
||||||
|
|
||||||
grid = (bsz, num_kv_heads)
|
grid = (bsz, num_kv_heads)
|
||||||
_copy_to_kvcache_seqlen1_kernel[grid](
|
_copy_to_kvcache_seqlen1_kernel[grid](
|
||||||
k,
|
k,
|
||||||
|
v,
|
||||||
k_cache,
|
k_cache,
|
||||||
|
v_cache,
|
||||||
block_tables,
|
block_tables,
|
||||||
kv_lengths,
|
kv_lengths,
|
||||||
k.stride(0),
|
k.stride(0),
|
||||||
k.stride(1),
|
k.stride(1),
|
||||||
k.stride(2),
|
k.stride(2),
|
||||||
|
v.stride(0),
|
||||||
|
v.stride(1),
|
||||||
|
v.stride(2),
|
||||||
k_cache.stride(0),
|
k_cache.stride(0),
|
||||||
k_cache.stride(1),
|
k_cache.stride(1),
|
||||||
k_cache.stride(2),
|
k_cache.stride(2),
|
||||||
k_cache.stride(3),
|
k_cache.stride(3),
|
||||||
|
v_cache.stride(0),
|
||||||
|
v_cache.stride(1),
|
||||||
|
v_cache.stride(2),
|
||||||
|
v_cache.stride(3),
|
||||||
block_tables.stride(0),
|
block_tables.stride(0),
|
||||||
block_tables.stride(1),
|
block_tables.stride(1),
|
||||||
block_size,
|
block_size,
|
||||||
|
@ -44,18 +44,19 @@ def prepare_data(
|
|||||||
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||||
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
|
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
|
||||||
|
|
||||||
k_cache, _, block_tables = generate_caches_and_block_tables_v2(
|
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||||
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
|
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
block_tables = block_tables.to(device=device)
|
block_tables = block_tables.to(device=device)
|
||||||
|
|
||||||
new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||||
|
new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||||
# mock allocating blocks for the new k/v and update block tables
|
# mock allocating blocks for the new k/v and update block tables
|
||||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||||
# kv seq len = past kv seq len + seq len (1 during decoding stage)
|
# kv seq len = past kv seq len + seq len (1 during decoding stage)
|
||||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||||
|
|
||||||
return new_k, k_cache, kv_seq_lengths, block_tables
|
return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||||
@ -80,7 +81,7 @@ def test_copy_kv_to_caches(
|
|||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
|
|
||||||
new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
|
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||||
bsz,
|
bsz,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
HEAD_DIM,
|
HEAD_DIM,
|
||||||
@ -93,16 +94,20 @@ def test_copy_kv_to_caches(
|
|||||||
)
|
)
|
||||||
# k_cache_torch = k_cache.clone().detach()
|
# k_cache_torch = k_cache.clone().detach()
|
||||||
# copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
|
# copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
|
||||||
copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
|
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
|
||||||
|
|
||||||
past_kv_seq_len = kv_seq_lengths - 1
|
past_kv_seq_len = kv_seq_lengths - 1
|
||||||
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||||
offsets_in_block = past_kv_seq_len % block_size
|
offsets_in_block = past_kv_seq_len % block_size
|
||||||
target = k_cache[target_block_ids, :, offsets_in_block, :]
|
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
|
||||||
source = new_k.squeeze()
|
k_source = new_k.squeeze()
|
||||||
|
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
|
||||||
|
v_source = new_v.squeeze()
|
||||||
|
|
||||||
assert target.shape == source.shape
|
assert k_target.shape == k_source.shape
|
||||||
assert torch.equal(target, source)
|
assert torch.equal(k_target, k_source)
|
||||||
|
assert v_target.shape == v_source.shape
|
||||||
|
assert torch.equal(v_target, v_source)
|
||||||
# target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
|
# target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
|
||||||
# assert target_torch.shape == source.shape
|
# assert target_torch.shape == source.shape
|
||||||
# assert torch.equal(target_torch, source)
|
# assert torch.equal(target_torch, source)
|
||||||
@ -143,7 +148,7 @@ def benchmark_kvcache_copy(
|
|||||||
|
|
||||||
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
|
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
|
||||||
|
|
||||||
new_k, k_cache, context_lengths, block_tables = prepare_data(
|
new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
|
||||||
bsz,
|
bsz,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
HEAD_DIM,
|
HEAD_DIM,
|
||||||
@ -156,10 +161,11 @@ def benchmark_kvcache_copy(
|
|||||||
)
|
)
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
|
||||||
if provider == "torch_copy_func":
|
if provider == "torch_copy_func":
|
||||||
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
|
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
|
||||||
if provider == "triton_copy_func":
|
if provider == "triton_copy_func":
|
||||||
fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
|
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||||
return ms, min_ms, max_ms
|
return ms, min_ms, max_ms
|
||||||
|
Loading…
Reference in New Issue
Block a user