[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)

* revise shape of kvcache (context attn kernel)

* revise shape of kvcache (flash decoding kernel)

* revise shape of kvcache (kvcache copy) and attn func

* init of kvcache in kvcache manager

* revise llama modeling

* revise block size retrieval

* use torch for rms_norm benchmarking

* revise block size retrieval
Yuanheng Zhao 2024-01-30 16:06:09 +08:00 committed by GitHub
parent e8f0642f28
commit 5f98a9d68a
14 changed files with 171 additions and 145 deletions
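The central change across these files is the blocked KV cache layout: each per-layer cache tensor moves from [num_blocks, num_kv_heads, head_size, block_size] to [num_blocks, num_kv_heads, block_size, head_size], so block_size is now read from the second-to-last dimension. A minimal illustrative snippet of what callers see after this patch (toy sizes, not taken from the repository):

import torch

# New layout: one slot per token inside a block, with head_size innermost.
num_blocks, num_kv_heads, block_size, head_size = 16, 8, 32, 128
k_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_size, dtype=torch.float16)

block_size_retrieved = k_cache.size(-2)  # was k_cache.shape[-1] under the old layout
head_size_retrieved = k_cache.size(-1)
assert (block_size_retrieved, head_size_retrieved) == (block_size, head_size)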

View File

@@ -79,10 +79,10 @@ class KVCacheManager:
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
         # Physical cache allocation
+        alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
         if verbose:
-            alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
             self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
-        self._kv_caches = self._init_device_caches()
+        self._kv_caches = self._init_device_caches(alloc_shape)
         self.total_physical_cache_size_in_bytes = (
             self.elem_size_in_bytes
             * self.num_layers
@@ -297,15 +297,12 @@ class KVCacheManager:
             blocks.append(cache_block)
         return blocks
-    def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Initialize the physical cache on the device.
         For each layer of the model, we allocate two tensors for key and value respectively,
-        with shape of [num_blocks, num_kv_heads, head_size, block_size]
+        with shape of [num_blocks, num_kv_heads, block_size, head_size]
         """
-        alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
-        # TODO: Explore the performance when using difference shapes with kernel-related optimizations
-        # e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x]
         k_cache: List[torch.Tensor] = []
         v_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
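For reference, a minimal sketch of the per-layer allocation that _init_device_caches now performs with the shape passed in from __init__ (names, dtype, and device handling here are assumptions, not the manager's exact code):

import torch
from typing import List, Tuple

def init_device_caches_sketch(
    num_layers: int, alloc_shape: Tuple[int, ...], dtype=torch.float16, device="cuda"
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # One key tensor and one value tensor per layer, each shaped
    # [num_blocks, num_kv_heads, block_size, head_size].
    k_caches: List[torch.Tensor] = []
    v_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        k_caches.append(torch.zeros(alloc_shape, dtype=dtype, device=device))
        v_caches.append(torch.zeros(alloc_shape, dtype=dtype, device=device))
    return k_caches, v_caches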

View File

@@ -16,7 +16,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
         lengths: key/value lengths
         block_tables
     """
-    num_blocks, num_heads, head_size, block_size = cache.shape
+    num_blocks, num_heads, block_size, head_size = cache.shape
     bsz, max_blocks_per_seq = block_tables.shape
     needed_blocks = (lengths + block_size - 1) // block_size
@@ -26,17 +26,17 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
             block_num = needed_blocks[i]
             token_id = 0
             for block_idx in range(block_num - 1):
-                cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
+                cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 0, 2)
                 token_id += block_size
-            cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
-                1, 2, 0
+            cache[block_tables[i][block_num - 1], :, : seq_len - token_id, :] = source[i][token_id:seq_len].permute(
+                1, 0, 2
             )
     elif type == "decoding":
         assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
         source = source.squeeze(1)
         slot_idx = (lengths + block_size - 1) % block_size
         for i in range(bsz):
-            cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
+            cache[block_tables[i, needed_blocks[i] - 1], :, slot_idx[i], :] = source[i]
     return cache
@@ -46,12 +46,12 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
    """
    Func: convert key/value cache for calculation
-    Args:   cache: shape [num_blocks, num_heads, head_size, block_size]
+    Args:   cache: shape [num_blocks, num_heads, block_size, head_size]
            lengths: key/value length
            block_tables
            pad_id: padded_id
    """
-    num_blocks, num_heads, head_size, block_size = cache.shape
+    num_blocks, num_heads, block_size, head_size = cache.shape
     needed_blocks = (lengths + block_size - 1) // block_size
     num_remaing_tokens = lengths % block_size
@@ -62,8 +62,8 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
     for i in range(bsz):
         _cache = torch.cat(
             (
-                cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size),
-                cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1),
+                cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 2, 1, 3)).reshape(-1, num_heads, head_size),
+                cache[block_tables[i][needed_blocks[i] - 1], :, : num_remaing_tokens[i], :].permute(1, 0, 2),
             ),
             dim=0,
         )
@@ -127,7 +127,7 @@ class PagedAttention:
         q: torch.Tensor,  # [num_tokens, num_heads, head_size]
         k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
         v: torch.Tensor,
-        k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        k_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
         v_cache: torch.Tensor,
         context_lengths: torch.Tensor,  # [num_seqs]
         block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
@@ -142,7 +142,7 @@ class PagedAttention:
         assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
         num_kv_groups = num_heads // num_kv_heads
-        block_size = k_cache.shape[-1]
+        block_size = k_cache.size(-2)
         bsz, max_blocks_per_sequence = block_tables.shape
         max_seq_len = max_blocks_per_sequence * block_size
         assert q.shape[-1] == k.shape[-1] == v.shape[-1]
@@ -196,7 +196,7 @@ class PagedAttention:
         q: torch.Tensor,  # [batch_size, seq_len, num_heads, head_size]
         k: torch.Tensor,  # [batch_size, seq_len, num_kv_heads, head_size]
         v: torch.Tensor,
-        k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        k_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
         v_cache: torch.Tensor,
         context_lengths: torch.Tensor,  # [num_seqs]
         block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
@@ -207,7 +207,7 @@ class PagedAttention:
         num_kv_heads = k.shape[-2]
         assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
         num_kv_groups = num_heads // num_kv_heads
-        block_size = k_cache.shape[-1]
+        block_size = k_cache.size(-2)
         assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
         block_tables.shape[-1] * block_size
@@ -254,7 +254,7 @@ class PagedAttention:
         q: torch.Tensor,  # [bsz, 1, num_heads, head_size]
         k: torch.Tensor,  # [bsz, 1, num_kv_heads, head_size]
         v: torch.Tensor,
-        k_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        k_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
         v_cache: torch.Tensor,
         lengths: torch.Tensor,  # [num_seqs]: input_lengths + output_lengths
         block_tables: torch.Tensor,  # [num_seqs,max_blocks_per_sequence]
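The permute changes above follow directly from the new layout: a [seq_len, num_heads, head_size] chunk is now written into each block as [num_heads, block_size, head_size]. A toy example (arbitrary sizes, not from the patch) mirroring the prefill path of copy_to_cache:

import torch

num_blocks, num_heads, block_size, head_size = 4, 3, 8, 16
cache = torch.zeros(num_blocks, num_heads, block_size, head_size)
seq = torch.randn(11, num_heads, head_size)   # one sequence of 11 tokens
block_table = torch.tensor([0, 2])            # blocks assigned to this sequence

cache[block_table[0]] = seq[:block_size].permute(1, 0, 2)                           # full first block
cache[block_table[1], :, : 11 - block_size, :] = seq[block_size:].permute(1, 0, 2)  # partial tail block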

View File

@@ -171,7 +171,7 @@ def llama_attn_forward(
     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-    _, _, _, block_size = k_cache.shape
+    block_size = k_cache.size(-2)
     if is_prompts:
         attn_output = context_attention_unpadded(

View File

@@ -226,7 +226,7 @@ def llama_attn_forward(
     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-    _, _, _, block_size = k_cache.shape
+    block_size = k_cache.size(-2)
     if is_prompts:
         attn_output = context_attention_unpadded(

View File

@@ -36,8 +36,8 @@ def _fwd_context_paged_attention_kernel(
     stride_od,
     stride_cacheb,
     stride_cacheh,
-    stride_cached,
     stride_cachebs,
+    stride_cached,
     stride_bts,
     stride_btb,
     context_lengths,
@@ -158,29 +158,29 @@ def _fwd_context_paged_attention_kernel(
     # Copy k to corresponding cache block
     offsets_dmodel = tl.arange(0, HEAD_DIM)
     offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
-    k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
+    offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
+    k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
     offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
     offsets_kcache = (
         KCache
         + offset_kvcache
-        + offsets_dmodel[:, None] * stride_cached
-        + offsets_kcachebs[None, :] * stride_cachebs
+        + offsets_dmodel[None, :] * stride_cached
+        + offsets_kcachebs[:, None] * stride_cachebs
     )
-    tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
+    tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
     # Copy v to corresponding cache block
     offsets_vd = offsets_dmodel
     offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
-    offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
-    v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
+    offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
+    v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
     offsets_vcachebs = offsets_kcachebs  # same block size range, just to notify here
     offsets_vcache = (
         VCache
         + offset_kvcache
-        + offsets_vcachebs[:, None] * stride_cachebs
-        + offsets_dmodel[None, :] * stride_cached
+        + offsets_vcachebs[None, :] * stride_cachebs
+        + offsets_dmodel[:, None] * stride_cached
     )
-    tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
+    tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
     return
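The kernel argument reorder (stride_cachebs now passed before stride_cached) and the swapped [:, None] / [None, :] broadcasts simply mirror the new memory layout, where head_dim is the innermost dimension of the cache. A quick host-side illustration of the strides involved (illustrative values only):

import torch

k_cache = torch.zeros(16, 8, 32, 128, dtype=torch.float16)  # [num_blocks, num_kv_heads, block_size, head_dim]
stride_cacheb, stride_cacheh, stride_cachebs, stride_cached = k_cache.stride()
assert stride_cached == 1      # head_dim elements are contiguous
assert stride_cachebs == 128   # moving one block slot skips a full head_dim row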

View File

@@ -10,8 +10,8 @@ import triton.language as tl
 @triton.jit
 def _flash_decoding_fwd_kernel(
     Q,  # [batch_size, head_num, q_len(1), head_dim]
-    KCache,  # [num_blocks, num_kv_heads, head_dim, block_size]
-    VCache,  # [num_blocks, num_kv_heads, head_dim, block_size]
+    KCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
+    VCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
     block_tables,  # [batch_size, max_blocks_per_sequence]
     mid_o,  # [batch_size, head_num, kv_split_num, head_dim]
     mid_o_lse,  # [batch_size, head_num, kv_split_num]
@@ -22,8 +22,8 @@ def _flash_decoding_fwd_kernel(
     stride_qd,
     stride_cacheb,
     stride_cacheh,
-    stride_cached,
     stride_cachebs,
+    stride_cached,
     stride_bts,
     stride_btb,
     stride_mid_ot,
@@ -79,18 +79,18 @@ def _flash_decoding_fwd_kernel(
     K_block_ptr = tl.make_block_ptr(
         base=KCache + offset_kvcache,
-        shape=(HEAD_DIM, cur_occupied_size),
-        strides=(stride_cached, stride_cachebs),
+        shape=(cur_occupied_size, HEAD_DIM),
+        strides=(stride_cachebs, stride_cached),
         offsets=(0, 0),
-        block_shape=(HEAD_DIM, BLOCK_SIZE),
+        block_shape=(BLOCK_SIZE, HEAD_DIM),
         order=(0, 1),
     )
     V_block_ptr = tl.make_block_ptr(
         base=VCache + offset_kvcache,
-        shape=(HEAD_DIM, cur_occupied_size),
-        strides=(stride_cached, stride_cachebs),
+        shape=(cur_occupied_size, HEAD_DIM),
+        strides=(stride_cachebs, stride_cached),
         offsets=(0, 0),
-        block_shape=(HEAD_DIM, BLOCK_SIZE),
+        block_shape=(BLOCK_SIZE, HEAD_DIM),
         order=(0, 1),
     )
     k_cur_block = tl.load(K_block_ptr)
@@ -102,7 +102,7 @@ def _flash_decoding_fwd_kernel(
     # NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16,
     # Multiplying two tensors with shapes [1, d] * [d, block_size] will fail.
     # Refer to https://github.com/openai/triton/discussions/895
-    S_ij += tl.sum(q[:, None] * k_cur_block, 0)
+    S_ij += tl.sum(q[None, :] * k_cur_block, 1)
     S_ij *= sm_scale
     S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))
@@ -111,7 +111,7 @@ def _flash_decoding_fwd_kernel(
     p_ij_hat = tl.exp(S_ij)
     l = tl.sum(p_ij_hat, 0)
     p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
-    acc += tl.sum(v_cur_block * p_ij_hat[None, :], 1)
+    acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
     acc = acc / l
     offsets_mid_o = (
@@ -206,8 +206,8 @@ def flash_decoding_attention(
     Args:
         q (torch.Tensor): [bsz, num_heads, head_dim]
-        k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
-        v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
+        k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
+        v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
         kv_seq_len (torch.Tensor): [batch_size]
             records the (kv) sequence lengths incorporating past kv sequence lengths.
         block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
@@ -230,13 +230,13 @@ def flash_decoding_attention(
     assert head_dim in {32, 64, 128, 256}
     assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
         f"Got incompatible batch size (number of seqs):\n"
-        f"  KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
+        f"  KV seq lengths bsz {kv_seq_len.size(0)}, Block tables bsz {block_tables.size(0)}, "
         f"batch size {bsz}"
     )
-    assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
+    assert k_cache.size(-2) == v_cache.size(-2) == block_size, (
         f"Got incompatible block size on kv caches:\n"
-        f"  assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
-        f"v_cache block_size {v_cache.size(-1)}"
+        f"  assigned block_size {block_size}, k_cache block_size {k_cache.size(-2)}, "
+        f"v_cache block_size {v_cache.size(-2)}"
     )
     # NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
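For readers who want a host-side mental model of what this decoding kernel computes over the new layout, here is a hedged torch reference. It assumes num_heads == num_kv_heads (no GQA grouping) and makes no claim to match the kernel's KV splitting or numerics exactly:

import torch

def blocked_decoding_attn_ref(q, k_cache, v_cache, kv_lengths, block_tables):
    # q: [bsz, num_heads, head_dim]; k_cache/v_cache: [num_blocks, num_kv_heads, block_size, head_dim]
    bsz, num_heads, head_dim = q.shape
    block_size = k_cache.size(-2)
    out = torch.empty_like(q)
    for i in range(bsz):
        n = int(kv_lengths[i])
        nb = (n + block_size - 1) // block_size
        block_ids = block_tables[i, :nb].long()
        # Gather this sequence's blocks and flatten (block, slot) into a token dimension.
        k = k_cache[block_ids].permute(1, 0, 2, 3).reshape(num_heads, -1, head_dim)[:, :n]
        v = v_cache[block_ids].permute(1, 0, 2, 3).reshape(num_heads, -1, head_dim)[:, :n]
        scores = (q[i].unsqueeze(1) @ k.transpose(1, 2)) * head_dim**-0.5  # [num_heads, 1, n]
        out[i] = (torch.softmax(scores, dim=-1) @ v).squeeze(1)
    return out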

View File

@@ -15,8 +15,8 @@ def _copy_to_kvcache_seqlen1_kernel(
     stride_kd,
     stride_cacheb,
     stride_cacheh,
-    stride_cached,
     stride_cachebs,
+    stride_cached,
     stride_bts,
     stride_btb,
     block_size,
@@ -29,15 +29,15 @@ def _copy_to_kvcache_seqlen1_kernel(
     last_bt_block_idx = past_kv_seq_len // block_size
     block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
     block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
-    offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs
+    offsets_in_last_block = past_kv_seq_len % block_size
     offsets_dmodel = tl.arange(0, HEAD_DIM)
     offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
     kv = tl.load(KV + offsets_kv)
     offsets_kvcache = (
         block_id * stride_cacheb
         + cur_kv_head_idx * stride_cacheh
+        + offsets_in_last_block * stride_cachebs
         + offsets_dmodel * stride_cached
-        + offsets_in_last_block
     )
     tl.store(KVCache + offsets_kvcache, kv)
     return
@@ -52,23 +52,18 @@ def copy_kv_to_blocked_cache(
     """
     Copy keys or values to the blocked key/value cache during decoding stage.
-    Parameters:
-    - k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
-    - k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
-    - kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
-    - block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
+    Args:
+        k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
+        k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
+        kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
+        block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
     """
-    assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
+    assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
     assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
-    if k.dim() == 4:
-        assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
-        bsz, _, num_kv_heads, head_dim = k.shape
-        # [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
-        k = k.squeeze(dim=1)
-    elif k.dim() == 3:
-        bsz, num_kv_heads, head_dim = k.shape
-    else:
-        raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.")
+    k = k.squeeze(1) if k.dim() == 4 else k
+    assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
+    bsz, num_kv_heads, head_dim = k.shape
     assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
         f"Got incompatible batch size (number of seqs):\n"
@@ -77,7 +72,7 @@ def copy_kv_to_blocked_cache(
     )
     # Modify if the shape of kv cahce is changed.
-    block_size = k_cache.size(-1)
+    block_size = k_cache.size(-2)
     num_warps = 8 if head_dim > 128 else 4
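The indexing used by the updated test further below (k_cache[target_block_ids, :, offsets_in_block, :]) doubles as a convenient torch-level description of what this kernel does for one decoding step. A hedged sketch, assuming k is [bsz, 1, num_kv_heads, head_dim] or [bsz, num_kv_heads, head_dim] and kv_lengths already counts the new token:

import torch

def copy_kv_to_blocked_cache_ref(k, k_cache, kv_lengths, block_tables):
    # k_cache layout: [num_blocks, num_kv_heads, block_size, head_dim]
    if k.dim() == 4:
        k = k.squeeze(1)                          # -> [bsz, num_kv_heads, head_dim]
    block_size = k_cache.size(-2)
    slots = (kv_lengths - 1).long()               # slot of the token being written
    seq_ids = torch.arange(k.size(0), device=k.device)
    block_ids = block_tables[seq_ids, slots // block_size].long()
    k_cache[block_ids, :, slots % block_size, :] = k
    return k_cache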

View File

@@ -93,7 +93,7 @@ def check_cache_manager(test_config):
     assert len(cache_manager._cache_blocks) == num_blocks
     key_caches = cache_manager._kv_caches[0]  # key caches for all the blocks in all the layers
     assert len(key_caches) == num_layers
-    expected_kv_shape = (num_blocks, num_attention_heads, head_size, block_size)
+    expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size)
     assert key_caches[0].shape == expected_kv_shape
     k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
     expected_kv_block_shape = expected_kv_shape[1:]

View File

@@ -1,20 +1,17 @@
-import pytest
 import torch
 from transformers.cache_utils import DynamicCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-import colossalai
 from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
-from colossalai.testing import rerun_if_address_is_in_use, spawn
 def test_copy_to_cache():
     key = torch.ones((2, 11, 3, 3))
     key[0, 9, :, :] = 0
     key[1, -2:, :, :] = 0
-    cache = torch.zeros(8, 3, 3, 8)
+    cache = torch.zeros(8, 3, 8, 3)
     block_tables = torch.tensor([[0, 1], [2, 3]])
     lengths = torch.tensor([9, 8])
     cache = copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="prefill")
@@ -28,7 +25,7 @@ def test_copy_to_cache():
 def test_convert_kvcache():
-    cache = torch.ones(8, 3, 3, 8)
+    cache = torch.ones(8, 3, 8, 3)
     key = torch.ones(2, 1, 3, 3) + 1
     lengths = torch.tensor([10, 9])
     block_tables = torch.tensor([[0, 1], [2, 3]])
@@ -43,8 +40,8 @@ def test_context_attention():
     """
     attn = PagedAttention()
     q = k = v = torch.randn(8, 4, 4)
-    k_cache = torch.empty(8, 4, 4, 8)
-    v_cache = torch.empty(8, 4, 4, 8)
+    k_cache = torch.empty(8, 4, 8, 4)
+    v_cache = torch.empty(8, 4, 8, 4)
     context_lengths = torch.tensor(
         [
             8,
@@ -136,23 +133,8 @@ def test_decoding_attention():
     assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)
-def check_attention_layer():
+if __name__ == "__main__":
     test_copy_to_cache()
     test_convert_kvcache()
     test_context_attention()
     test_decoding_attention()
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
-    check_attention_layer()
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_attention_layer():
-    spawn(run_dist, 1)
-if __name__ == "__main__":
-    test_attention_layer()

View File

@@ -106,6 +106,40 @@ def mock_alloc_block_table_and_kvcache(
     return block_tables
+def mock_alloc_block_table_and_kvcache_v2(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    context_lengths: torch.Tensor,
+    num_seqs: int,
+    max_num_blocks_per_seq: int,
+    block_size: int,
+) -> torch.Tensor:
+    """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
+    block_id = 0
+    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
+    num_tokens_processed = 0
+    for i, seq_len in enumerate(context_lengths.tolist()):
+        right_bound = (seq_len + block_size - 1) // block_size  # open bound
+        block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
+        # Manually fill kv caches by copying from k and v
+        for i in range(right_bound):
+            if i == right_bound - 1:
+                allocated_locs = seq_len % block_size or block_size
+            else:
+                allocated_locs = block_size
+            k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
+            v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
+            k_cache[block_id, :, :allocated_locs, :] = k_block
+            v_cache[block_id, :, :allocated_locs, :] = v_block
+            num_tokens_processed += allocated_locs
+            block_id += 1
+    return block_tables
 def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
     # Allocate 1 token on the block table for each seqs in block tables.
     # It won't change provided context_lengths.
@@ -146,6 +180,22 @@ generate_caches_and_block_tables(
     return k_cache, v_cache, block_tables
+def generate_caches_and_block_tables_v2(
+    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
+) -> Tuple[torch.Tensor, ...]:
+    # Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths
+    # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
+    _, num_kv_heads, head_dim = k_unpad.shape
+    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
+    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
+    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
+    # Mock allocation on block tables as well as blocked kv caches
+    block_tables = mock_alloc_block_table_and_kvcache_v2(
+        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
+    )
+    return k_cache, v_cache, block_tables
 def convert_kv_unpad_to_padded(
     k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
 ) -> torch.Tensor:

View File

@@ -6,7 +6,7 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.kernel.triton import context_attention_unpadded
 from colossalai.utils import get_current_device
-from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables, torch_attn_ref
+from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
 try:
     import triton  # noqa
@@ -93,7 +93,7 @@ def test_context_attention(
     q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
     q_unpad = q_unpad.contiguous()
-    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables(
+    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
         k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
     )
     block_tables = block_tables.to(device=device)
@@ -148,7 +148,6 @@ def bench_kernel(
     num_kv_heads = num_attn_heads // kv_group_num
     assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
-    block_size * max_num_blocks_per_seq
     dtype = torch.float16
     device = get_current_device()
@@ -162,7 +161,7 @@ def bench_kernel(
     qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
     q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
     q_unpad = q_unpad.contiguous()
-    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables(
+    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
         k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
     )
     block_tables = block_tables.to(device=device)

View File

@@ -6,7 +6,7 @@ from colossalai.kernel.triton import flash_decoding_attention
 from colossalai.utils import get_current_device
 from tests.test_infer_ops.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
-    generate_caches_and_block_tables,
+    generate_caches_and_block_tables_v2,
     prepare_padding_mask,
     torch_attn_ref,
 )
@@ -38,6 +38,9 @@ def prepare_data(
 ):
     # Use the provided maximum sequence length for each sequence when testing with teh same context length,
     # otherwise generate random context lengths.
+    # returns
+    #   q [bsz, num_attn_heads, q_len, head_dim]
+    #   k_unpad/v_unpad [num_tokens, num_kv_heads, head_dim]
     kv_lengths = (
         torch.tensor([max_kv_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
         if same_context_len
@@ -83,7 +86,7 @@ def test_flash_decoding(
     q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
         bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
     )
-    k_cache, v_cache, block_tables = generate_caches_and_block_tables(
+    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
         k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
     )
     block_tables = block_tables.to(device=device)
@@ -180,7 +183,7 @@ def bench_kernel(
     )
     ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
     if provider == "triton":
-        k_cache, v_cache, block_tables = generate_caches_and_block_tables(
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
         )
         block_tables = block_tables.to(device=device)

View File

@@ -5,7 +5,7 @@ from packaging import version
 from colossalai.inference.modeling.layers.attention import copy_to_cache
 from colossalai.kernel.triton import copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
-from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, mock_alloc_single_token
+from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
 try:
     import triton  # noqa
@@ -17,6 +17,8 @@ except ImportError:
 TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
+HEAD_DIM = 128
 def prepare_data(
     bsz,
@@ -29,31 +31,27 @@ def prepare_data(
     device,
     dtype=torch.float16,
 ):
-    if same_context_len:
-        # past_kv_seq_lengths in this test records the previous kv seq len
-        # (not incorporating the current input whose seq len is 1)
-        past_kv_seq_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
-    else:
-        past_kv_seq_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
+    # past_kv_seq_lengths in this test records the previous kv seq len
+    # (not incorporating the current input whose seq len is 1)
+    past_kv_seq_lengths = (
+        torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
+        if same_context_len
+        else torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
+    )
     num_tokens = torch.sum(past_kv_seq_lengths).item()
     kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
-    kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-    k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)
-    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
-    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
-    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
-    # Mock allocation on block tables as well as blocked kv caches
-    block_tables = mock_alloc_block_table_and_kvcache(
-        k, v, k_cache, v_cache, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size
+    kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+    k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
+    k_cache, _, block_tables = generate_caches_and_block_tables_v2(
+        k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
     )
     block_tables = block_tables.to(device=device)
     new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
     # mock allocating blocks for the new k/v and update block tables
     mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
     # kv seq len = past kv seq len + seq len (1 during decoding stage)
     kv_seq_lengths = past_kv_seq_lengths + 1
@@ -78,7 +76,6 @@ def test_copy_kv_to_caches(
     torch.cuda.synchronize()
     torch.cuda.reset_peak_memory_stats()
-    head_dim = 128
     max_seq_len = block_size * max_num_blocks_per_seq
     dtype = torch.float16
     device = get_current_device()
@@ -86,7 +83,7 @@ def test_copy_kv_to_caches(
     new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
         bsz,
         num_kv_heads,
-        head_dim,
+        HEAD_DIM,
         block_size,
         max_num_blocks_per_seq,
         same_context_len,
@@ -94,20 +91,28 @@ def test_copy_kv_to_caches(
         device=device,
         dtype=dtype,
     )
+    # k_cache_torch = k_cache.clone().detach()
+    # copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
     copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
-    for seq_i in range(bsz):
-        ki = new_k[seq_i]
-        ki = ki.squeeze()
-        past_kv_seq_len = kv_seq_lengths[seq_i] - 1
-        target_block_id = block_tables[seq_i, past_kv_seq_len // block_size]
-        offsets_in_block = past_kv_seq_len % block_size
-        target = k_cache[target_block_id, :, :, offsets_in_block]
-        orig = new_k[seq_i].squeeze(dim=0)
-        assert torch.equal(orig, target)
+    past_kv_seq_len = kv_seq_lengths - 1
+    target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
+    offsets_in_block = past_kv_seq_len % block_size
+    target = k_cache[target_block_ids, :, offsets_in_block, :]
+    source = new_k.squeeze()
+    assert target.shape == source.shape
+    assert torch.equal(target, source)
+    # target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
+    # assert target_torch.shape == source.shape
+    # assert torch.equal(target_torch, source)
 BATCH = 16
+BLOCK_SIZE = 32
+SAME_LEN = True
+WARM_UPS = 10
+REPS = 100
 configs = [
     triton.testing.Benchmark(
         x_names=["KV_SEQ_LEN"],
@@ -133,10 +138,6 @@ def benchmark_kvcache_copy(
     num_kv_heads: int,
     same_context_len: bool,
 ):
-    warmup = 10
-    rep = 100
-    head_dim = 128
     dtype = torch.float16
     device = get_current_device()
@@ -145,7 +146,7 @@ def benchmark_kvcache_copy(
     new_k, k_cache, context_lengths, block_tables = prepare_data(
         bsz,
         num_kv_heads,
-        head_dim,
+        HEAD_DIM,
         block_size,
         max_seq_len // block_size,
         same_context_len,
@@ -154,15 +155,14 @@ def benchmark_kvcache_copy(
         dtype=dtype,
     )
+    quantiles = [0.5, 0.2, 0.8]
     if provider == "torch_copy_func":
         fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
-    elif provider == "triton_copy_func":
+    if provider == "triton_copy_func":
         fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
-    else:
-        raise ValueError("Undefined provider.")
-    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
-    return ms
+    ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
+    return ms, min_ms, max_ms
 if __name__ == "__main__":

View File

@@ -3,7 +3,6 @@ import torch
 import triton
 from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
-from vllm.model_executor.layers.layernorm import RMSNorm
 from colossalai.kernel.triton import rms_layernorm
 from colossalai.testing.utils import parameterize
@@ -36,7 +35,8 @@ def test_layer_norm(M, N):
     y_triton = rms_layernorm(x, weight, eps=eps)
     y_llama = rms_norm.forward(x).to(dtype)
-    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
+    assert y_triton.shape == y_llama.shape
+    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
 # Triton benchmark plot attributions
@@ -45,8 +45,8 @@ configs = [
         x_names=["SEQUENCE_TOTAL"],
         x_vals=[i for i in range(128, 1025, 128)],
         line_arg="provider",
-        line_vals=["vllm_rms_layernorm", "triton_rms_layernorm"],
-        line_names=["vllm_rms_layernorm", "triton_rms_layernorm"],
+        line_vals=["torch_rms_layernorm", "triton_rms_layernorm"],
+        line_names=["torch_rms_layernorm", "triton_rms_layernorm"],
         styles=[("red", "-"), ("blue", "-")],
         ylabel="ms",
         plot_name=f"RMSNorm benchmarking results",
@@ -69,10 +69,10 @@ def benchmark_rms_layernorm(
     x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
     w_shape = (x_shape[-1],)
     weight = torch.ones(w_shape, dtype=dtype, device="cuda")
-    vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
+    torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-    if provider == "vllm_rms_layernorm":
-        fn = lambda: vllm_norm(x)
+    if provider == "torch_rms_layernorm":
+        fn = lambda: torch_norm(x)
     elif provider == "triton_rms_layernorm":
         fn = lambda: rms_layernorm(x, weight, eps=eps)
     else: