mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 20:23:41 +00:00
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel) * revise shape of kvcache (flash decoding kernel) * revise shape of kvcache (kvcache copy) and attn func * init of kvcache in kvcache manager * revise llama modeling * revise block size retrieval * use torch for rms_norm benchmarking * revise block size retrieval
This commit is contained in:
parent
e8f0642f28
commit
5f98a9d68a
@ -79,10 +79,10 @@ class KVCacheManager:
|
|||||||
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
|
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
|
||||||
|
|
||||||
# Physical cache allocation
|
# Physical cache allocation
|
||||||
|
alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
|
||||||
if verbose:
|
if verbose:
|
||||||
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
|
|
||||||
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
|
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
|
||||||
self._kv_caches = self._init_device_caches()
|
self._kv_caches = self._init_device_caches(alloc_shape)
|
||||||
self.total_physical_cache_size_in_bytes = (
|
self.total_physical_cache_size_in_bytes = (
|
||||||
self.elem_size_in_bytes
|
self.elem_size_in_bytes
|
||||||
* self.num_layers
|
* self.num_layers
|
||||||
@ -297,15 +297,12 @@ class KVCacheManager:
|
|||||||
blocks.append(cache_block)
|
blocks.append(cache_block)
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Initialize the physical cache on the device.
|
"""Initialize the physical cache on the device.
|
||||||
|
|
||||||
For each layer of the model, we allocate two tensors for key and value respectively,
|
For each layer of the model, we allocate two tensors for key and value respectively,
|
||||||
with shape of [num_blocks, num_kv_heads, head_size, block_size]
|
with shape of [num_blocks, num_kv_heads, block_size, head_size]
|
||||||
"""
|
"""
|
||||||
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
|
|
||||||
# TODO: Explore the performance when using difference shapes with kernel-related optimizations
|
|
||||||
# e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x]
|
|
||||||
k_cache: List[torch.Tensor] = []
|
k_cache: List[torch.Tensor] = []
|
||||||
v_cache: List[torch.Tensor] = []
|
v_cache: List[torch.Tensor] = []
|
||||||
for _ in range(self.num_layers):
|
for _ in range(self.num_layers):
|
||||||
|
@ -16,7 +16,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
|||||||
lengths: key/value lengths
|
lengths: key/value lengths
|
||||||
block_tables
|
block_tables
|
||||||
"""
|
"""
|
||||||
num_blocks, num_heads, head_size, block_size = cache.shape
|
num_blocks, num_heads, block_size, head_size = cache.shape
|
||||||
bsz, max_blocks_per_seq = block_tables.shape
|
bsz, max_blocks_per_seq = block_tables.shape
|
||||||
needed_blocks = (lengths + block_size - 1) // block_size
|
needed_blocks = (lengths + block_size - 1) // block_size
|
||||||
|
|
||||||
@ -26,17 +26,17 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
|||||||
block_num = needed_blocks[i]
|
block_num = needed_blocks[i]
|
||||||
token_id = 0
|
token_id = 0
|
||||||
for block_idx in range(block_num - 1):
|
for block_idx in range(block_num - 1):
|
||||||
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
|
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 0, 2)
|
||||||
token_id += block_size
|
token_id += block_size
|
||||||
cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
|
cache[block_tables[i][block_num - 1], :, : seq_len - token_id, :] = source[i][token_id:seq_len].permute(
|
||||||
1, 2, 0
|
1, 0, 2
|
||||||
)
|
)
|
||||||
elif type == "decoding":
|
elif type == "decoding":
|
||||||
assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
|
assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
|
||||||
source = source.squeeze(1)
|
source = source.squeeze(1)
|
||||||
slot_idx = (lengths + block_size - 1) % block_size
|
slot_idx = (lengths + block_size - 1) % block_size
|
||||||
for i in range(bsz):
|
for i in range(bsz):
|
||||||
cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
|
cache[block_tables[i, needed_blocks[i] - 1], :, slot_idx[i], :] = source[i]
|
||||||
|
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
@ -46,12 +46,12 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
|
|||||||
"""
|
"""
|
||||||
Func: convert key/value cache for calculation
|
Func: convert key/value cache for calculation
|
||||||
|
|
||||||
Args: cache: shape [num_blocks, num_heads, head_size, block_size]
|
Args: cache: shape [num_blocks, num_heads, block_size, head_size]
|
||||||
lengths: key/value length
|
lengths: key/value length
|
||||||
block_tables
|
block_tables
|
||||||
pad_id: padded_id
|
pad_id: padded_id
|
||||||
"""
|
"""
|
||||||
num_blocks, num_heads, head_size, block_size = cache.shape
|
num_blocks, num_heads, block_size, head_size = cache.shape
|
||||||
|
|
||||||
needed_blocks = (lengths + block_size - 1) // block_size
|
needed_blocks = (lengths + block_size - 1) // block_size
|
||||||
num_remaing_tokens = lengths % block_size
|
num_remaing_tokens = lengths % block_size
|
||||||
@ -62,8 +62,8 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
|
|||||||
for i in range(bsz):
|
for i in range(bsz):
|
||||||
_cache = torch.cat(
|
_cache = torch.cat(
|
||||||
(
|
(
|
||||||
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size),
|
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 2, 1, 3)).reshape(-1, num_heads, head_size),
|
||||||
cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1),
|
cache[block_tables[i][needed_blocks[i] - 1], :, : num_remaing_tokens[i], :].permute(1, 0, 2),
|
||||||
),
|
),
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
@ -127,7 +127,7 @@ class PagedAttention:
|
|||||||
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
|
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
|
||||||
v_cache: torch.Tensor,
|
v_cache: torch.Tensor,
|
||||||
context_lengths: torch.Tensor, # [num_seqs]
|
context_lengths: torch.Tensor, # [num_seqs]
|
||||||
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||||
@ -142,7 +142,7 @@ class PagedAttention:
|
|||||||
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
|
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
|
||||||
num_kv_groups = num_heads // num_kv_heads
|
num_kv_groups = num_heads // num_kv_heads
|
||||||
|
|
||||||
block_size = k_cache.shape[-1]
|
block_size = k_cache.size(-2)
|
||||||
bsz, max_blocks_per_sequence = block_tables.shape
|
bsz, max_blocks_per_sequence = block_tables.shape
|
||||||
max_seq_len = max_blocks_per_sequence * block_size
|
max_seq_len = max_blocks_per_sequence * block_size
|
||||||
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
|
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
|
||||||
@ -196,7 +196,7 @@ class PagedAttention:
|
|||||||
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
|
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
|
||||||
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
|
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
|
||||||
v_cache: torch.Tensor,
|
v_cache: torch.Tensor,
|
||||||
context_lengths: torch.Tensor, # [num_seqs]
|
context_lengths: torch.Tensor, # [num_seqs]
|
||||||
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||||
@ -207,7 +207,7 @@ class PagedAttention:
|
|||||||
num_kv_heads = k.shape[-2]
|
num_kv_heads = k.shape[-2]
|
||||||
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
|
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
|
||||||
num_kv_groups = num_heads // num_kv_heads
|
num_kv_groups = num_heads // num_kv_heads
|
||||||
block_size = k_cache.shape[-1]
|
block_size = k_cache.size(-2)
|
||||||
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
|
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
|
||||||
block_tables.shape[-1] * block_size
|
block_tables.shape[-1] * block_size
|
||||||
|
|
||||||
@ -254,7 +254,7 @@ class PagedAttention:
|
|||||||
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
|
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
|
||||||
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
|
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
|
||||||
v_cache: torch.Tensor,
|
v_cache: torch.Tensor,
|
||||||
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
|
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
|
||||||
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||||
|
@ -171,7 +171,7 @@ def llama_attn_forward(
|
|||||||
|
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
|
|
||||||
_, _, _, block_size = k_cache.shape
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
if is_prompts:
|
if is_prompts:
|
||||||
attn_output = context_attention_unpadded(
|
attn_output = context_attention_unpadded(
|
||||||
|
@ -226,7 +226,7 @@ def llama_attn_forward(
|
|||||||
|
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
|
|
||||||
_, _, _, block_size = k_cache.shape
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
if is_prompts:
|
if is_prompts:
|
||||||
attn_output = context_attention_unpadded(
|
attn_output = context_attention_unpadded(
|
||||||
|
@ -36,8 +36,8 @@ def _fwd_context_paged_attention_kernel(
|
|||||||
stride_od,
|
stride_od,
|
||||||
stride_cacheb,
|
stride_cacheb,
|
||||||
stride_cacheh,
|
stride_cacheh,
|
||||||
stride_cached,
|
|
||||||
stride_cachebs,
|
stride_cachebs,
|
||||||
|
stride_cached,
|
||||||
stride_bts,
|
stride_bts,
|
||||||
stride_btb,
|
stride_btb,
|
||||||
context_lengths,
|
context_lengths,
|
||||||
@ -158,29 +158,29 @@ def _fwd_context_paged_attention_kernel(
|
|||||||
# Copy k to corresponding cache block
|
# Copy k to corresponding cache block
|
||||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||||
offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
|
offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
|
||||||
k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
|
k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
|
||||||
offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
|
offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
|
||||||
offsets_kcache = (
|
offsets_kcache = (
|
||||||
KCache
|
KCache
|
||||||
+ offset_kvcache
|
+ offset_kvcache
|
||||||
+ offsets_dmodel[:, None] * stride_cached
|
+ offsets_dmodel[None, :] * stride_cached
|
||||||
+ offsets_kcachebs[None, :] * stride_cachebs
|
+ offsets_kcachebs[:, None] * stride_cachebs
|
||||||
)
|
)
|
||||||
tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||||
# Copy v to corresponding cache block
|
# Copy v to corresponding cache block
|
||||||
offsets_vd = offsets_dmodel
|
offsets_vd = offsets_dmodel
|
||||||
offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
|
offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
|
offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
|
||||||
v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
|
v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
|
||||||
offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here
|
offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here
|
||||||
offsets_vcache = (
|
offsets_vcache = (
|
||||||
VCache
|
VCache
|
||||||
+ offset_kvcache
|
+ offset_kvcache
|
||||||
+ offsets_vcachebs[:, None] * stride_cachebs
|
+ offsets_vcachebs[None, :] * stride_cachebs
|
||||||
+ offsets_dmodel[None, :] * stride_cached
|
+ offsets_dmodel[:, None] * stride_cached
|
||||||
)
|
)
|
||||||
tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -10,8 +10,8 @@ import triton.language as tl
|
|||||||
@triton.jit
|
@triton.jit
|
||||||
def _flash_decoding_fwd_kernel(
|
def _flash_decoding_fwd_kernel(
|
||||||
Q, # [batch_size, head_num, q_len(1), head_dim]
|
Q, # [batch_size, head_num, q_len(1), head_dim]
|
||||||
KCache, # [num_blocks, num_kv_heads, head_dim, block_size]
|
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
VCache, # [num_blocks, num_kv_heads, head_dim, block_size]
|
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
block_tables, # [batch_size, max_blocks_per_sequence]
|
block_tables, # [batch_size, max_blocks_per_sequence]
|
||||||
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
|
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
|
||||||
mid_o_lse, # [batch_size, head_num, kv_split_num]
|
mid_o_lse, # [batch_size, head_num, kv_split_num]
|
||||||
@ -22,8 +22,8 @@ def _flash_decoding_fwd_kernel(
|
|||||||
stride_qd,
|
stride_qd,
|
||||||
stride_cacheb,
|
stride_cacheb,
|
||||||
stride_cacheh,
|
stride_cacheh,
|
||||||
stride_cached,
|
|
||||||
stride_cachebs,
|
stride_cachebs,
|
||||||
|
stride_cached,
|
||||||
stride_bts,
|
stride_bts,
|
||||||
stride_btb,
|
stride_btb,
|
||||||
stride_mid_ot,
|
stride_mid_ot,
|
||||||
@ -79,18 +79,18 @@ def _flash_decoding_fwd_kernel(
|
|||||||
|
|
||||||
K_block_ptr = tl.make_block_ptr(
|
K_block_ptr = tl.make_block_ptr(
|
||||||
base=KCache + offset_kvcache,
|
base=KCache + offset_kvcache,
|
||||||
shape=(HEAD_DIM, cur_occupied_size),
|
shape=(cur_occupied_size, HEAD_DIM),
|
||||||
strides=(stride_cached, stride_cachebs),
|
strides=(stride_cachebs, stride_cached),
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
block_shape=(HEAD_DIM, BLOCK_SIZE),
|
block_shape=(BLOCK_SIZE, HEAD_DIM),
|
||||||
order=(0, 1),
|
order=(0, 1),
|
||||||
)
|
)
|
||||||
V_block_ptr = tl.make_block_ptr(
|
V_block_ptr = tl.make_block_ptr(
|
||||||
base=VCache + offset_kvcache,
|
base=VCache + offset_kvcache,
|
||||||
shape=(HEAD_DIM, cur_occupied_size),
|
shape=(cur_occupied_size, HEAD_DIM),
|
||||||
strides=(stride_cached, stride_cachebs),
|
strides=(stride_cachebs, stride_cached),
|
||||||
offsets=(0, 0),
|
offsets=(0, 0),
|
||||||
block_shape=(HEAD_DIM, BLOCK_SIZE),
|
block_shape=(BLOCK_SIZE, HEAD_DIM),
|
||||||
order=(0, 1),
|
order=(0, 1),
|
||||||
)
|
)
|
||||||
k_cur_block = tl.load(K_block_ptr)
|
k_cur_block = tl.load(K_block_ptr)
|
||||||
@ -102,7 +102,7 @@ def _flash_decoding_fwd_kernel(
|
|||||||
# NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16,
|
# NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16,
|
||||||
# Multiplying two tensors with shapes [1, d] * [d, block_size] will fail.
|
# Multiplying two tensors with shapes [1, d] * [d, block_size] will fail.
|
||||||
# Refer to https://github.com/openai/triton/discussions/895
|
# Refer to https://github.com/openai/triton/discussions/895
|
||||||
S_ij += tl.sum(q[:, None] * k_cur_block, 0)
|
S_ij += tl.sum(q[None, :] * k_cur_block, 1)
|
||||||
S_ij *= sm_scale
|
S_ij *= sm_scale
|
||||||
S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))
|
S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ def _flash_decoding_fwd_kernel(
|
|||||||
p_ij_hat = tl.exp(S_ij)
|
p_ij_hat = tl.exp(S_ij)
|
||||||
l = tl.sum(p_ij_hat, 0)
|
l = tl.sum(p_ij_hat, 0)
|
||||||
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
|
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
|
||||||
acc += tl.sum(v_cur_block * p_ij_hat[None, :], 1)
|
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
|
||||||
acc = acc / l
|
acc = acc / l
|
||||||
|
|
||||||
offsets_mid_o = (
|
offsets_mid_o = (
|
||||||
@ -206,8 +206,8 @@ def flash_decoding_attention(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
q (torch.Tensor): [bsz, num_heads, head_dim]
|
q (torch.Tensor): [bsz, num_heads, head_dim]
|
||||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
|
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
|
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
kv_seq_len (torch.Tensor): [batch_size]
|
kv_seq_len (torch.Tensor): [batch_size]
|
||||||
records the (kv) sequence lengths incorporating past kv sequence lengths.
|
records the (kv) sequence lengths incorporating past kv sequence lengths.
|
||||||
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
|
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
|
||||||
@ -230,13 +230,13 @@ def flash_decoding_attention(
|
|||||||
assert head_dim in {32, 64, 128, 256}
|
assert head_dim in {32, 64, 128, 256}
|
||||||
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
|
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
|
||||||
f"Got incompatible batch size (number of seqs):\n"
|
f"Got incompatible batch size (number of seqs):\n"
|
||||||
f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
|
f" KV seq lengths bsz {kv_seq_len.size(0)}, Block tables bsz {block_tables.size(0)}, "
|
||||||
f"batch size {bsz}"
|
f"batch size {bsz}"
|
||||||
)
|
)
|
||||||
assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
|
assert k_cache.size(-2) == v_cache.size(-2) == block_size, (
|
||||||
f"Got incompatible block size on kv caches:\n"
|
f"Got incompatible block size on kv caches:\n"
|
||||||
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
|
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-2)}, "
|
||||||
f"v_cache block_size {v_cache.size(-1)}"
|
f"v_cache block_size {v_cache.size(-2)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
|
# NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
|
||||||
|
@ -15,8 +15,8 @@ def _copy_to_kvcache_seqlen1_kernel(
|
|||||||
stride_kd,
|
stride_kd,
|
||||||
stride_cacheb,
|
stride_cacheb,
|
||||||
stride_cacheh,
|
stride_cacheh,
|
||||||
stride_cached,
|
|
||||||
stride_cachebs,
|
stride_cachebs,
|
||||||
|
stride_cached,
|
||||||
stride_bts,
|
stride_bts,
|
||||||
stride_btb,
|
stride_btb,
|
||||||
block_size,
|
block_size,
|
||||||
@ -29,15 +29,15 @@ def _copy_to_kvcache_seqlen1_kernel(
|
|||||||
last_bt_block_idx = past_kv_seq_len // block_size
|
last_bt_block_idx = past_kv_seq_len // block_size
|
||||||
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
|
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
|
||||||
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
||||||
offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs
|
offsets_in_last_block = past_kv_seq_len % block_size
|
||||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||||
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||||
kv = tl.load(KV + offsets_kv)
|
kv = tl.load(KV + offsets_kv)
|
||||||
offsets_kvcache = (
|
offsets_kvcache = (
|
||||||
block_id * stride_cacheb
|
block_id * stride_cacheb
|
||||||
+ cur_kv_head_idx * stride_cacheh
|
+ cur_kv_head_idx * stride_cacheh
|
||||||
|
+ offsets_in_last_block * stride_cachebs
|
||||||
+ offsets_dmodel * stride_cached
|
+ offsets_dmodel * stride_cached
|
||||||
+ offsets_in_last_block
|
|
||||||
)
|
)
|
||||||
tl.store(KVCache + offsets_kvcache, kv)
|
tl.store(KVCache + offsets_kvcache, kv)
|
||||||
return
|
return
|
||||||
@ -52,23 +52,18 @@ def copy_kv_to_blocked_cache(
|
|||||||
"""
|
"""
|
||||||
Copy keys or values to the blocked key/value cache during decoding stage.
|
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||||
|
|
||||||
Parameters:
|
Args:
|
||||||
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||||
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
|
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
|
||||||
- kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||||
- block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||||
"""
|
"""
|
||||||
assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
|
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
|
||||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||||
if k.dim() == 4:
|
|
||||||
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
|
k = k.squeeze(1) if k.dim() == 4 else k
|
||||||
bsz, _, num_kv_heads, head_dim = k.shape
|
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
|
||||||
# [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
|
bsz, num_kv_heads, head_dim = k.shape
|
||||||
k = k.squeeze(dim=1)
|
|
||||||
elif k.dim() == 3:
|
|
||||||
bsz, num_kv_heads, head_dim = k.shape
|
|
||||||
else:
|
|
||||||
raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.")
|
|
||||||
|
|
||||||
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||||
f"Got incompatible batch size (number of seqs):\n"
|
f"Got incompatible batch size (number of seqs):\n"
|
||||||
@ -77,7 +72,7 @@ def copy_kv_to_blocked_cache(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Modify if the shape of kv cahce is changed.
|
# Modify if the shape of kv cahce is changed.
|
||||||
block_size = k_cache.size(-1)
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
num_warps = 8 if head_dim > 128 else 4
|
num_warps = 8 if head_dim > 128 else 4
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ def check_cache_manager(test_config):
|
|||||||
assert len(cache_manager._cache_blocks) == num_blocks
|
assert len(cache_manager._cache_blocks) == num_blocks
|
||||||
key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers
|
key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers
|
||||||
assert len(key_caches) == num_layers
|
assert len(key_caches) == num_layers
|
||||||
expected_kv_shape = (num_blocks, num_attention_heads, head_size, block_size)
|
expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size)
|
||||||
assert key_caches[0].shape == expected_kv_shape
|
assert key_caches[0].shape == expected_kv_shape
|
||||||
k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
|
k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
|
||||||
expected_kv_block_shape = expected_kv_shape[1:]
|
expected_kv_block_shape = expected_kv_shape[1:]
|
||||||
|
@ -1,20 +1,17 @@
|
|||||||
import pytest
|
|
||||||
import torch
|
import torch
|
||||||
from transformers.cache_utils import DynamicCache
|
from transformers.cache_utils import DynamicCache
|
||||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||||
|
|
||||||
import colossalai
|
|
||||||
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
|
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
|
||||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
|
||||||
|
|
||||||
|
|
||||||
def test_copy_to_cache():
|
def test_copy_to_cache():
|
||||||
key = torch.ones((2, 11, 3, 3))
|
key = torch.ones((2, 11, 3, 3))
|
||||||
key[0, 9, :, :] = 0
|
key[0, 9, :, :] = 0
|
||||||
key[1, -2:, :, :] = 0
|
key[1, -2:, :, :] = 0
|
||||||
cache = torch.zeros(8, 3, 3, 8)
|
cache = torch.zeros(8, 3, 8, 3)
|
||||||
block_tables = torch.tensor([[0, 1], [2, 3]])
|
block_tables = torch.tensor([[0, 1], [2, 3]])
|
||||||
lengths = torch.tensor([9, 8])
|
lengths = torch.tensor([9, 8])
|
||||||
cache = copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="prefill")
|
cache = copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="prefill")
|
||||||
@ -28,7 +25,7 @@ def test_copy_to_cache():
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_kvcache():
|
def test_convert_kvcache():
|
||||||
cache = torch.ones(8, 3, 3, 8)
|
cache = torch.ones(8, 3, 8, 3)
|
||||||
key = torch.ones(2, 1, 3, 3) + 1
|
key = torch.ones(2, 1, 3, 3) + 1
|
||||||
lengths = torch.tensor([10, 9])
|
lengths = torch.tensor([10, 9])
|
||||||
block_tables = torch.tensor([[0, 1], [2, 3]])
|
block_tables = torch.tensor([[0, 1], [2, 3]])
|
||||||
@ -43,8 +40,8 @@ def test_context_attention():
|
|||||||
"""
|
"""
|
||||||
attn = PagedAttention()
|
attn = PagedAttention()
|
||||||
q = k = v = torch.randn(8, 4, 4)
|
q = k = v = torch.randn(8, 4, 4)
|
||||||
k_cache = torch.empty(8, 4, 4, 8)
|
k_cache = torch.empty(8, 4, 8, 4)
|
||||||
v_cache = torch.empty(8, 4, 4, 8)
|
v_cache = torch.empty(8, 4, 8, 4)
|
||||||
context_lengths = torch.tensor(
|
context_lengths = torch.tensor(
|
||||||
[
|
[
|
||||||
8,
|
8,
|
||||||
@ -136,23 +133,8 @@ def test_decoding_attention():
|
|||||||
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)
|
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
def check_attention_layer():
|
if __name__ == "__main__":
|
||||||
test_copy_to_cache()
|
test_copy_to_cache()
|
||||||
test_convert_kvcache()
|
test_convert_kvcache()
|
||||||
test_context_attention()
|
test_context_attention()
|
||||||
test_decoding_attention()
|
test_decoding_attention()
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
|
||||||
check_attention_layer()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
|
||||||
@rerun_if_address_is_in_use()
|
|
||||||
def test_attention_layer():
|
|
||||||
spawn(run_dist, 1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_attention_layer()
|
|
||||||
|
@ -106,6 +106,40 @@ def mock_alloc_block_table_and_kvcache(
|
|||||||
return block_tables
|
return block_tables
|
||||||
|
|
||||||
|
|
||||||
|
def mock_alloc_block_table_and_kvcache_v2(
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
context_lengths: torch.Tensor,
|
||||||
|
num_seqs: int,
|
||||||
|
max_num_blocks_per_seq: int,
|
||||||
|
block_size: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
|
||||||
|
block_id = 0
|
||||||
|
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
|
||||||
|
num_tokens_processed = 0
|
||||||
|
for i, seq_len in enumerate(context_lengths.tolist()):
|
||||||
|
right_bound = (seq_len + block_size - 1) // block_size # open bound
|
||||||
|
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
|
||||||
|
# Manually fill kv caches by copying from k and v
|
||||||
|
for i in range(right_bound):
|
||||||
|
if i == right_bound - 1:
|
||||||
|
allocated_locs = seq_len % block_size or block_size
|
||||||
|
else:
|
||||||
|
allocated_locs = block_size
|
||||||
|
k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
|
||||||
|
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
|
||||||
|
k_cache[block_id, :, :allocated_locs, :] = k_block
|
||||||
|
v_cache[block_id, :, :allocated_locs, :] = v_block
|
||||||
|
|
||||||
|
num_tokens_processed += allocated_locs
|
||||||
|
block_id += 1
|
||||||
|
|
||||||
|
return block_tables
|
||||||
|
|
||||||
|
|
||||||
def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
|
def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
|
||||||
# Allocate 1 token on the block table for each seqs in block tables.
|
# Allocate 1 token on the block table for each seqs in block tables.
|
||||||
# It won't change provided context_lengths.
|
# It won't change provided context_lengths.
|
||||||
@ -146,6 +180,22 @@ def generate_caches_and_block_tables(
|
|||||||
return k_cache, v_cache, block_tables
|
return k_cache, v_cache, block_tables
|
||||||
|
|
||||||
|
|
||||||
|
def generate_caches_and_block_tables_v2(
|
||||||
|
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
|
||||||
|
) -> Tuple[torch.Tensor, ...]:
|
||||||
|
# Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths
|
||||||
|
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
|
||||||
|
_, num_kv_heads, head_dim = k_unpad.shape
|
||||||
|
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
|
||||||
|
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||||
|
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||||
|
# Mock allocation on block tables as well as blocked kv caches
|
||||||
|
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||||
|
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
|
||||||
|
)
|
||||||
|
return k_cache, v_cache, block_tables
|
||||||
|
|
||||||
|
|
||||||
def convert_kv_unpad_to_padded(
|
def convert_kv_unpad_to_padded(
|
||||||
k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
|
k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
@ -6,7 +6,7 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|||||||
from colossalai.inference.modeling.layers.attention import PagedAttention
|
from colossalai.inference.modeling.layers.attention import PagedAttention
|
||||||
from colossalai.kernel.triton import context_attention_unpadded
|
from colossalai.kernel.triton import context_attention_unpadded
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables, torch_attn_ref
|
from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton # noqa
|
import triton # noqa
|
||||||
@ -93,7 +93,7 @@ def test_context_attention(
|
|||||||
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
||||||
q_unpad = q_unpad.contiguous()
|
q_unpad = q_unpad.contiguous()
|
||||||
|
|
||||||
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables(
|
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
|
||||||
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||||
)
|
)
|
||||||
block_tables = block_tables.to(device=device)
|
block_tables = block_tables.to(device=device)
|
||||||
@ -148,7 +148,6 @@ def bench_kernel(
|
|||||||
|
|
||||||
num_kv_heads = num_attn_heads // kv_group_num
|
num_kv_heads = num_attn_heads // kv_group_num
|
||||||
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
||||||
block_size * max_num_blocks_per_seq
|
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
|
|
||||||
@ -162,7 +161,7 @@ def bench_kernel(
|
|||||||
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||||
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
||||||
q_unpad = q_unpad.contiguous()
|
q_unpad = q_unpad.contiguous()
|
||||||
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables(
|
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
|
||||||
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||||
)
|
)
|
||||||
block_tables = block_tables.to(device=device)
|
block_tables = block_tables.to(device=device)
|
||||||
|
@ -6,7 +6,7 @@ from colossalai.kernel.triton import flash_decoding_attention
|
|||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from tests.test_infer_ops.triton.kernel_utils import (
|
from tests.test_infer_ops.triton.kernel_utils import (
|
||||||
convert_kv_unpad_to_padded,
|
convert_kv_unpad_to_padded,
|
||||||
generate_caches_and_block_tables,
|
generate_caches_and_block_tables_v2,
|
||||||
prepare_padding_mask,
|
prepare_padding_mask,
|
||||||
torch_attn_ref,
|
torch_attn_ref,
|
||||||
)
|
)
|
||||||
@ -38,6 +38,9 @@ def prepare_data(
|
|||||||
):
|
):
|
||||||
# Use the provided maximum sequence length for each sequence when testing with teh same context length,
|
# Use the provided maximum sequence length for each sequence when testing with teh same context length,
|
||||||
# otherwise generate random context lengths.
|
# otherwise generate random context lengths.
|
||||||
|
# returns
|
||||||
|
# q [bsz, num_attn_heads, q_len, head_dim]
|
||||||
|
# k_unpad/v_unpad [num_tokens, num_kv_heads, head_dim]
|
||||||
kv_lengths = (
|
kv_lengths = (
|
||||||
torch.tensor([max_kv_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
|
torch.tensor([max_kv_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||||
if same_context_len
|
if same_context_len
|
||||||
@ -83,7 +86,7 @@ def test_flash_decoding(
|
|||||||
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
|
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
|
||||||
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
|
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
|
||||||
)
|
)
|
||||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables(
|
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||||
k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||||
)
|
)
|
||||||
block_tables = block_tables.to(device=device)
|
block_tables = block_tables.to(device=device)
|
||||||
@ -180,7 +183,7 @@ def bench_kernel(
|
|||||||
)
|
)
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||||
if provider == "triton":
|
if provider == "triton":
|
||||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables(
|
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||||
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||||
)
|
)
|
||||||
block_tables = block_tables.to(device=device)
|
block_tables = block_tables.to(device=device)
|
||||||
|
@ -5,7 +5,7 @@ from packaging import version
|
|||||||
from colossalai.inference.modeling.layers.attention import copy_to_cache
|
from colossalai.inference.modeling.layers.attention import copy_to_cache
|
||||||
from colossalai.kernel.triton import copy_kv_to_blocked_cache
|
from colossalai.kernel.triton import copy_kv_to_blocked_cache
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, mock_alloc_single_token
|
from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton # noqa
|
import triton # noqa
|
||||||
@ -17,6 +17,8 @@ except ImportError:
|
|||||||
|
|
||||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||||
|
|
||||||
|
HEAD_DIM = 128
|
||||||
|
|
||||||
|
|
||||||
def prepare_data(
|
def prepare_data(
|
||||||
bsz,
|
bsz,
|
||||||
@ -29,31 +31,27 @@ def prepare_data(
|
|||||||
device,
|
device,
|
||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
):
|
):
|
||||||
if same_context_len:
|
# past_kv_seq_lengths in this test records the previous kv seq len
|
||||||
# past_kv_seq_lengths in this test records the previous kv seq len
|
# (not incorporating the current input whose seq len is 1)
|
||||||
# (not incorporating the current input whose seq len is 1)
|
past_kv_seq_lengths = (
|
||||||
past_kv_seq_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
|
torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||||
else:
|
if same_context_len
|
||||||
past_kv_seq_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
|
else torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
|
||||||
|
)
|
||||||
num_tokens = torch.sum(past_kv_seq_lengths).item()
|
num_tokens = torch.sum(past_kv_seq_lengths).item()
|
||||||
|
|
||||||
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
|
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
|
||||||
kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||||
k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)
|
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
|
||||||
|
|
||||||
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
|
k_cache, _, block_tables = generate_caches_and_block_tables_v2(
|
||||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
|
||||||
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
|
||||||
# Mock allocation on block tables as well as blocked kv caches
|
|
||||||
block_tables = mock_alloc_block_table_and_kvcache(
|
|
||||||
k, v, k_cache, v_cache, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size
|
|
||||||
)
|
)
|
||||||
block_tables = block_tables.to(device=device)
|
block_tables = block_tables.to(device=device)
|
||||||
|
|
||||||
new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||||
# mock allocating blocks for the new k/v and update block tables
|
# mock allocating blocks for the new k/v and update block tables
|
||||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||||
|
|
||||||
# kv seq len = past kv seq len + seq len (1 during decoding stage)
|
# kv seq len = past kv seq len + seq len (1 during decoding stage)
|
||||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||||
|
|
||||||
@ -78,7 +76,6 @@ def test_copy_kv_to_caches(
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
head_dim = 128
|
|
||||||
max_seq_len = block_size * max_num_blocks_per_seq
|
max_seq_len = block_size * max_num_blocks_per_seq
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
@ -86,7 +83,7 @@ def test_copy_kv_to_caches(
|
|||||||
new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
|
new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||||
bsz,
|
bsz,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
head_dim,
|
HEAD_DIM,
|
||||||
block_size,
|
block_size,
|
||||||
max_num_blocks_per_seq,
|
max_num_blocks_per_seq,
|
||||||
same_context_len,
|
same_context_len,
|
||||||
@ -94,20 +91,28 @@ def test_copy_kv_to_caches(
|
|||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
# k_cache_torch = k_cache.clone().detach()
|
||||||
|
# copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
|
||||||
copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
|
copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
|
||||||
|
|
||||||
for seq_i in range(bsz):
|
past_kv_seq_len = kv_seq_lengths - 1
|
||||||
ki = new_k[seq_i]
|
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||||
ki = ki.squeeze()
|
offsets_in_block = past_kv_seq_len % block_size
|
||||||
past_kv_seq_len = kv_seq_lengths[seq_i] - 1
|
target = k_cache[target_block_ids, :, offsets_in_block, :]
|
||||||
target_block_id = block_tables[seq_i, past_kv_seq_len // block_size]
|
source = new_k.squeeze()
|
||||||
offsets_in_block = past_kv_seq_len % block_size
|
|
||||||
target = k_cache[target_block_id, :, :, offsets_in_block]
|
assert target.shape == source.shape
|
||||||
orig = new_k[seq_i].squeeze(dim=0)
|
assert torch.equal(target, source)
|
||||||
assert torch.equal(orig, target)
|
# target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
|
||||||
|
# assert target_torch.shape == source.shape
|
||||||
|
# assert torch.equal(target_torch, source)
|
||||||
|
|
||||||
|
|
||||||
BATCH = 16
|
BATCH = 16
|
||||||
|
BLOCK_SIZE = 32
|
||||||
|
SAME_LEN = True
|
||||||
|
WARM_UPS = 10
|
||||||
|
REPS = 100
|
||||||
configs = [
|
configs = [
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=["KV_SEQ_LEN"],
|
x_names=["KV_SEQ_LEN"],
|
||||||
@ -133,10 +138,6 @@ def benchmark_kvcache_copy(
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
same_context_len: bool,
|
same_context_len: bool,
|
||||||
):
|
):
|
||||||
warmup = 10
|
|
||||||
rep = 100
|
|
||||||
|
|
||||||
head_dim = 128
|
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
|
|
||||||
@ -145,7 +146,7 @@ def benchmark_kvcache_copy(
|
|||||||
new_k, k_cache, context_lengths, block_tables = prepare_data(
|
new_k, k_cache, context_lengths, block_tables = prepare_data(
|
||||||
bsz,
|
bsz,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
head_dim,
|
HEAD_DIM,
|
||||||
block_size,
|
block_size,
|
||||||
max_seq_len // block_size,
|
max_seq_len // block_size,
|
||||||
same_context_len,
|
same_context_len,
|
||||||
@ -154,15 +155,14 @@ def benchmark_kvcache_copy(
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
if provider == "torch_copy_func":
|
if provider == "torch_copy_func":
|
||||||
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
|
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
|
||||||
elif provider == "triton_copy_func":
|
if provider == "triton_copy_func":
|
||||||
fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
|
fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
|
||||||
else:
|
|
||||||
raise ValueError("Undefined provider.")
|
|
||||||
|
|
||||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||||
return ms
|
return ms, min_ms, max_ms
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -3,7 +3,6 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
||||||
|
|
||||||
from colossalai.kernel.triton import rms_layernorm
|
from colossalai.kernel.triton import rms_layernorm
|
||||||
from colossalai.testing.utils import parameterize
|
from colossalai.testing.utils import parameterize
|
||||||
@ -36,7 +35,8 @@ def test_layer_norm(M, N):
|
|||||||
y_triton = rms_layernorm(x, weight, eps=eps)
|
y_triton = rms_layernorm(x, weight, eps=eps)
|
||||||
y_llama = rms_norm.forward(x).to(dtype)
|
y_llama = rms_norm.forward(x).to(dtype)
|
||||||
|
|
||||||
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
|
assert y_triton.shape == y_llama.shape
|
||||||
|
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
# Triton benchmark plot attributions
|
# Triton benchmark plot attributions
|
||||||
@ -45,8 +45,8 @@ configs = [
|
|||||||
x_names=["SEQUENCE_TOTAL"],
|
x_names=["SEQUENCE_TOTAL"],
|
||||||
x_vals=[i for i in range(128, 1025, 128)],
|
x_vals=[i for i in range(128, 1025, 128)],
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["vllm_rms_layernorm", "triton_rms_layernorm"],
|
line_vals=["torch_rms_layernorm", "triton_rms_layernorm"],
|
||||||
line_names=["vllm_rms_layernorm", "triton_rms_layernorm"],
|
line_names=["torch_rms_layernorm", "triton_rms_layernorm"],
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
styles=[("red", "-"), ("blue", "-")],
|
||||||
ylabel="ms",
|
ylabel="ms",
|
||||||
plot_name=f"RMSNorm benchmarking results",
|
plot_name=f"RMSNorm benchmarking results",
|
||||||
@ -69,10 +69,10 @@ def benchmark_rms_layernorm(
|
|||||||
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
|
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
|
||||||
w_shape = (x_shape[-1],)
|
w_shape = (x_shape[-1],)
|
||||||
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
|
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
|
||||||
vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
|
torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
|
||||||
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||||
if provider == "vllm_rms_layernorm":
|
if provider == "torch_rms_layernorm":
|
||||||
fn = lambda: vllm_norm(x)
|
fn = lambda: torch_norm(x)
|
||||||
elif provider == "triton_rms_layernorm":
|
elif provider == "triton_rms_layernorm":
|
||||||
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user