[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… (#5663)

* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator

* refactor decode_kv_cache_memcpy

* enable alibi in pagedattention
Steve Luo 2024-04-30 15:52:23 +08:00 committed by GitHub
parent 5f00002e43
commit 5cd75ce4c7
14 changed files with 368 additions and 235 deletions

View File

@@ -90,9 +90,18 @@ class KVCacheManager:
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

         # Physical cache allocation
-        alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
-        self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
-        self._kv_caches = self._init_device_caches(alloc_shape)
+        if config.use_cuda_kernel:
+            x = 16 // torch.tensor([], dtype=config.dtype).element_size()
+            kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
+            valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+            self.logger.info(
+                f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
+            )
+            self._kv_caches = self._init_device_caches(kalloc_shape, valloc_shape)
+        else:
+            alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
+            self._kv_caches = self._init_device_caches(alloc_shape, alloc_shape)
         self.total_physical_cache_size_in_bytes = (
             self.elem_size_in_bytes
             * self.num_layers
@@ -479,7 +488,9 @@ class KVCacheManager:
             blocks.append(cache_block)
         return blocks

-    def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _init_device_caches(
+        self, kalloc_shape: Tuple[int, ...], valloc_shape: Tuple[int, ...]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Initialize the physical cache on the device.

         For each layer of the model, we allocate two tensors for key and value respectively,
@@ -488,6 +499,6 @@ class KVCacheManager:
         k_cache: List[torch.Tensor] = []
         v_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
-            k_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
-            v_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
+            k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device))
+            v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device))
         return k_cache, v_cache
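Note on the new K-cache layout: when the CUDA kernels are enabled, the key cache switches to the vLLM-style `[num_blocks, kv_head_num, head_dim // x, block_size, x]` layout, where `x` is the number of cache elements that fit into 16 bytes, so the innermost dimension is exactly one 16-byte vector. The V cache keeps the old `[num_blocks, kv_head_num, block_size, head_size]` layout. A minimal sketch of how the two allocation shapes relate and how an element of the old layout maps into the new one (sizes are made up for illustration):

    import torch

    num_blocks, kv_head_num, block_size, head_size = 4, 2, 16, 128
    dtype = torch.float16

    # As in the diff: x is the number of cache elements that fit into 16 bytes.
    x = 16 // torch.tensor([], dtype=dtype).element_size()  # 8 for fp16, 4 for fp32
    kalloc_shape = (num_blocks, kv_head_num, head_size // x, block_size, x)
    valloc_shape = (num_blocks, kv_head_num, block_size, head_size)

    # A key element at (block b, head h, slot s, channel d) in the old layout
    # lives at new_k_cache[b, h, d // x, s, d % x] in the new layout, i.e. each
    # token's key row is split into head_size // x contiguous 16-byte chunks.
    old_k_cache = torch.randn(num_blocks, kv_head_num, block_size, head_size, dtype=dtype)
    new_k_cache = (
        old_k_cache.reshape(num_blocks, kv_head_num, block_size, head_size // x, x)
        .permute(0, 1, 3, 2, 4)
        .contiguous()
    )
    assert new_k_cache.shape == kalloc_shape

    b, h, s, d = 1, 0, 5, 42
    assert old_k_cache[b, h, s, d] == new_k_cache[b, h, d // x, s, d % x]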

View File

@@ -310,6 +310,7 @@ class NopadBaichuanAttention(ParallelModule):
                 alibi_slopes=self.alibi_slopes,
                 max_seq_len=kv_seq_len,
                 sm_scale=sm_scale,
+                use_new_kcache_layout=use_cuda_kernel,
             )
         else:
             q_len = tokens_to_verify + 1 if is_verifier else 1
@@ -332,6 +333,21 @@ class NopadBaichuanAttention(ParallelModule):
                 inference_ops.decode_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
                 )
+                inference_ops.flash_decoding_attention(
+                    output_tensor,
+                    query_states,
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
+                    block_size,
+                    kv_seq_len,
+                    fd_inter_tensor.mid_output,
+                    fd_inter_tensor.mid_output_lse,
+                    self.alibi_slopes,
+                    sm_scale,
+                )
+                attn_output = output_tensor
             else:
                 if not is_verifier and not self.use_alibi_attn:
                     decoding_fused_rotary_embedding(
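With `alibi_slopes` passed straight into the CUDA decode kernel, ALiBi models such as Baichuan no longer need to fall back to the Triton path during decoding. For orientation, ALiBi slopes are usually the geometric sequence 2^(-8i/n) over the n heads; the sketch below shows a common construction and is illustrative only (the repo's `get_alibi_slopes`, imported by the tests later in this commit, may differ in details such as handling non-power-of-two head counts):

    import torch

    def alibi_slopes_sketch(num_heads: int, device: str = "cpu") -> torch.Tensor:
        # Common ALiBi schedule: slopes 2^(-8/n), 2^(-16/n), ... for n heads.
        # Only valid as-is for power-of-two head counts.
        base = 2 ** (-8.0 / num_heads)
        return torch.tensor([base ** (i + 1) for i in range(num_heads)], device=device)

    # During decoding the kernel adds `slope * (token_idx - context_len + 1)` to each
    # attention logit: zero for the newest token, increasingly negative for older ones.
    print(alibi_slopes_sketch(8))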

View File

@@ -98,15 +98,8 @@ def llama_model_forward(
     """
     block_tables = inputmetadata.block_tables
     sequence_lengths = inputmetadata.sequence_lengths
-    batch_size = inputmetadata.batch_size
     kv_seq_len = inputmetadata.kv_seq_len

-    # NOTE: After testing, the performance of this configuration is relatively good. With updates
-    # and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
-    # selection should be conducted.
-    if batch_size >= 32 and kv_seq_len > 512:
-        use_cuda_kernel = False
-
     # NOTE (yuanheng-zhao): fow now, only triton kernels support verification process
     # during speculative-decoding (`q_len > 1`)
     # We will expicitly disable `use_cuda_kernel` here when speculative-decoding is enabled
@@ -575,6 +568,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
                 output=output_tensor,
                 max_seq_len=kv_seq_len,
                 sm_scale=sm_scale,
+                use_new_kcache_layout=use_cuda_kernel,
             )
         else:
             q_len = tokens_to_verify + 1 if is_verifier else 1
@@ -592,20 +586,21 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
                     block_tables,
                     high_precision,
                 )
-                # inference_ops.flash_decoding_attention(
-                #     output_tensor,
-                #     query_states,
-                #     k_cache,
-                #     v_cache,
-                #     sequence_lengths,
-                #     block_tables,
-                #     block_size,
-                #     kv_seq_len,
-                #     fd_inter_tensor.mid_output,
-                #     fd_inter_tensor.mid_output_lse,
-                #     sm_scale,
-                # )
-                # attn_output = output_tensor
+                inference_ops.flash_decoding_attention(
+                    output_tensor,
+                    query_states,
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
+                    block_size,
+                    kv_seq_len,
+                    fd_inter_tensor.mid_output,
+                    fd_inter_tensor.mid_output_lse,
+                    None,
+                    sm_scale,
+                )
+                attn_output = output_tensor
             else:
                 if is_verifier:
                     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
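The previously commented-out CUDA decode-attention call is enabled here; the new `None` argument is the `alibi_slopes` parameter added by this commit (Llama does not use ALiBi). The kernel writes per-split partial results into the `fd_inter_tensor` workspace and reduces them afterwards. A minimal sketch of how those workspace buffers are typically sized, following the shapes used in the benchmark and tests in this commit (sizes below are illustrative):

    import torch

    batch_size, num_heads, head_size = 16, 32, 128
    block_size, max_seq_len = 16, 2048
    device, dtype = "cuda", torch.float16

    # Flash-decoding splits each sequence's KV cache into chunks of `block_size`
    # tokens; each chunk yields a partial output and a log-sum-exp, so the
    # workspace is sized by the maximum possible number of splits.
    kv_max_split_num = (max_seq_len + block_size - 1) // block_size

    output = torch.empty(batch_size, num_heads, head_size, dtype=dtype, device=device)
    mid_output = torch.empty(
        batch_size, num_heads, kv_max_split_num, head_size, dtype=torch.float32, device=device
    )
    mid_output_lse = torch.empty(
        batch_size, num_heads, kv_max_split_num, dtype=torch.float32, device=device
    )
    sm_scale = 1.0 / head_size**0.5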

View File

@@ -20,7 +20,7 @@ inference_ops = InferenceOpsLoader().load()
 configs = [
     triton.testing.Benchmark(
         x_names=["MAX_NUM_BLOCKS_PER_SEQ"],
-        x_vals=[2**i for i in range(3, 8)],
+        x_vals=[2**i for i in range(2, 8)],
         line_arg="provider",
         line_vals=[
             "vllm_paged_decoding_attention",
@@ -113,6 +113,8 @@ def benchmark_flash_decoding_attention(
     kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
     output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
     sm_scale = 1.0 / (HEAD_SIZE**0.5)
+    alibi_slopes = None
+    kv_scale = 1.0

     mid_output = torch.empty(
         size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
@@ -136,6 +138,7 @@ def benchmark_flash_decoding_attention(
             max_seq_len_across_batch,
             alibi_slopes,
             "auto",
+            kv_scale,
         )
     elif provider == "triton_flash_decoding_attention":
         fn = lambda: flash_decoding_attention(
@@ -164,6 +167,7 @@ def benchmark_flash_decoding_attention(
             max_seq_len_across_batch,
             mid_output,
             mid_output_lse,
+            alibi_slopes,
             sm_scale,
         )
     else:

View File

@@ -2,7 +2,11 @@ import torch
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
-from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
+from tests.test_infer.test_ops.triton.kernel_utils import (
+    mock_alloc_block_table_and_kvcache_v2,
+    mock_alloc_block_table_and_kvcache_v3,
+    mock_alloc_single_token,
+)

 inference_ops = InferenceOpsLoader().load()
@@ -68,11 +72,17 @@ def benchmark_rotary_emb(
     cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
     k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
     v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    new_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
+    new_k_cache = torch.zeros(size=new_cache_shape, dtype=dtype, device="cuda")
     past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
     block_tables = mock_alloc_block_table_and_kvcache_v2(
         k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
     )
+    _ = mock_alloc_block_table_and_kvcache_v3(
+        k, v, new_k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+    )
     new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
     new_q = torch.randn_like(new_k)
     new_v = torch.randn_like(new_k)
@@ -94,12 +104,12 @@ def benchmark_rotary_emb(
         )
     elif provider == "no_fused_cuda_rotary_emb_func":
         fn = lambda: [
-            inference_ops.rotary_embedding(new_q, new_k, cos, sin),
-            inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
+            inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
+            inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables),
         ]
     elif provider == "fused_cuda_rotary_emb_func":
         fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
-            new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
+            new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True
         )
     else:
         raise ValueError("Undefined provider")

View File

@@ -4,6 +4,7 @@ from colossalai.inference.modeling.layers.attention import copy_to_cache
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
+from tests.test_infer.test_ops.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
 from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data

 try:
@@ -68,6 +69,9 @@ def benchmark_kvcache_copy(
     elif provider == "triton_copy_func":
         fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
     elif provider == "cuda_copy_func":
+        _, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout(
+            bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype
+        )
         new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
         new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
         fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)

View File

@@ -24,7 +24,8 @@ __global__ void context_kv_cache_memcpy_kernel(
     const int batch_size,
     const int block_table_stride,
     const int64_t key_stride,
-    const int64_t value_stride
+    const int64_t value_stride,
+    const int x
 )
 {
     const int seq_token_id = blockIdx.x;
@@ -40,23 +41,33 @@ __global__ void context_kv_cache_memcpy_kernel(
     const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
     int head_id;
     int head_offset;
+    int x_id;
+    int x_offset;
     int64_t key_src_id;
     int64_t value_src_id;
-    int64_t target_id;
+    int64_t target_key_id;
+    int64_t target_value_id;

     int i = threadIdx.x * VecSize;

     for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
         head_id = i / head_dim;
         head_offset = i % head_dim;
+        x_id = head_offset / x;
+        x_offset = head_offset % x;
         key_src_id = total_token_id * key_stride + i;
         value_src_id = total_token_id * value_stride + i;
-        target_id = block_id * hidden_size * block_size
+        target_key_id = block_id * hidden_size * block_size
+                        + head_id * block_size * head_dim
+                        + x_id * block_size * x
+                        + block_offset * x
+                        + x_offset;
+        target_value_id = block_id * hidden_size * block_size
                     + head_id * block_size * head_dim
                     + block_offset * head_dim + head_offset;

-        copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_id);
-        copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_id);
+        copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_key_id);
+        copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_value_id);
     }

     // tail process
@@ -64,14 +75,21 @@ __global__ void context_kv_cache_memcpy_kernel(
     for (; i < hidden_size; ++i ) {
         head_id = i / head_dim;
         head_offset = i % head_dim;
+        x_id = head_offset / x;
+        x_offset = head_offset % x;
         key_src_id = total_token_id * key_stride + i;
         value_src_id = total_token_id * value_stride + i;
-        target_id = block_id * hidden_size * block_size
+        target_key_id = block_id * hidden_size * block_size
+                        + head_id * block_size * head_dim
+                        + x_id * block_size * x
+                        + block_offset * x
+                        + x_offset;
+        target_value_id = block_id * hidden_size * block_size
                     + head_id * block_size * head_dim
                     + block_offset * head_dim + head_offset;

-        key_cache[target_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
-        value_cache[target_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
+        key_cache[target_key_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
+        value_cache[target_value_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
     }
 }
@@ -81,7 +99,7 @@ template<typename T, typename CacheT>
 void apply_context_kv_cache_memcpy(
     torch::Tensor& key,                 // [num_tokens, head_num, head_dim]
     torch::Tensor& value,               // [num_tokens, head_num, head_dim]
-    torch::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
+    torch::Tensor& key_cache,           // [num_blocks, head_num, head_dim/x, block_size, x]
     torch::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
     torch::Tensor& sequence_lengths,    // [batch_size]
     torch::Tensor& cu_seqlens,          // [batch_size + 1]
@@ -91,7 +109,8 @@ void apply_context_kv_cache_memcpy(
     int num_tokens = key.size(0);
     int head_num = key.size(1);
     int head_dim = key.size(2);
-    int block_size = key_cache.size(2);
+    int block_size = key_cache.size(3);
+    int x = key_cache.size(4);
     int batch_size = block_tables.size(0);

     int64_t key_stride = key.stride(0);
@@ -127,7 +146,8 @@ void apply_context_kv_cache_memcpy(
                 batch_size, \
                 block_table_stride, \
                 key_stride, \
-                value_stride \
+                value_stride, \
+                x \
             ); \
     } while(0)
@@ -164,7 +184,7 @@ void apply_context_kv_cache_memcpy(
 void context_kv_cache_memcpy(
     torch::Tensor& key,                 // [num_tokens, head_num, head_dim]
     torch::Tensor& value,               // [num_tokens, head_num, head_dim]
-    torch::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
+    torch::Tensor& key_cache,           // [num_blocks, head_num, head_dim/x, block_size, x]
     torch::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
     torch::Tensor& sequence_lengths,    // [batch_size]
     torch::Tensor& cu_seqlens,          // [batch_size + 1]
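The index arithmetic above is the core of the layout change: for the key cache the flat offset becomes `block_id * head_num * head_dim * block_size + head_id * block_size * head_dim + (head_offset / x) * block_size * x + block_offset * x + head_offset % x`, while the value cache keeps the old `[..., block_size, head_dim]` order. A slow PyTorch reference of the same prefill copy, useful as a mental model (names are illustrative, not the repo's API):

    import torch

    def context_kv_cache_memcpy_ref(key, value, key_cache, value_cache,
                                    seq_lens, cu_seqlens, block_tables, block_size):
        # key/value:   [num_tokens, kv_head_num, head_dim]
        # key_cache:   [num_blocks, kv_head_num, head_dim // x, block_size, x]
        # value_cache: [num_blocks, kv_head_num, block_size, head_dim]
        x = key_cache.size(-1)
        for seq_id, seq_len in enumerate(seq_lens.tolist()):
            for token_pos in range(seq_len):
                token_id = cu_seqlens[seq_id].item() + token_pos
                block_id = block_tables[seq_id, token_pos // block_size].item()
                block_offset = token_pos % block_size
                k = key[token_id]   # [kv_head_num, head_dim]
                v = value[token_id]
                # New K layout: channel d goes to [d // x, block_offset, d % x].
                key_cache[block_id, :, :, block_offset, :] = k.view(k.size(0), -1, x)
                # V layout is unchanged.
                value_cache[block_id, :, block_offset, :] = v
        return key_cache, value_cache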

View File

@@ -20,7 +20,8 @@ __global__ void decode_kv_cache_memcpy_kernel(
     const int block_size,
     const int64_t key_stride,
     const int64_t value_stride,
-    const int block_table_stride
+    const int block_table_stride,
+    const int x
 )
 {
     const int seq_id = blockIdx.x;
@@ -38,28 +39,42 @@ __global__ void decode_kv_cache_memcpy_kernel(
     for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
         const int head_id = i / head_dim;
         const int head_offset = i % head_dim;
+        const int x_id = head_offset / x;
+        const int x_offset = head_offset % x;
         const int64_t key_src_id = seq_id * key_stride + i;
         const int64_t value_src_id = seq_id * value_stride + i;
-        const int64_t target_id = block_id * hidden_size * block_size
+        const int64_t target_key_id = block_id * hidden_size * block_size
+                                      + head_id * block_size * head_dim
+                                      + x_id * block_size * x
+                                      + block_offset * x
+                                      + x_offset;
+        const int64_t target_value_id = block_id * hidden_size * block_size
                                       + head_id * block_size * head_dim
                                       + block_offset * head_dim + head_offset;

-        copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
-        copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
+        copy_vector<scalar_t, VecSize>(key_cache + target_key_id, key + key_src_id);
+        copy_vector<scalar_t, VecSize>(value_cache + target_value_id, value + value_src_id);
     }

     if (!Aligned) {
         for (; i < hidden_size; ++i ) {
             const int head_id = i / head_dim;
             const int head_offset = i % head_dim;
+            const int x_id = head_offset / x;
+            const int x_offset = head_offset % x;
             const int64_t key_src_id = seq_id * key_stride + i;
             const int64_t value_src_id = seq_id * value_stride + i;
-            const int64_t target_id = block_id * hidden_size * block_size
+            const int64_t target_key_id = block_id * hidden_size * block_size
+                                          + head_id * block_size * head_dim
+                                          + x_id * block_size * x
+                                          + block_offset * x
+                                          + x_offset;
+            const int64_t target_value_id = block_id * hidden_size * block_size
                                           + head_id * block_size * head_dim
                                           + block_offset * head_dim + head_offset;

-            key_cache[target_id] = key[key_src_id];
-            value_cache[target_id] = value[value_src_id];
+            key_cache[target_key_id] = key[key_src_id];
+            value_cache[target_value_id] = value[value_src_id];
         }
     }
@@ -69,7 +84,7 @@ template<typename scalar_t>
 void apply_decode_kv_cache_memcpy(
     at::Tensor& key,                 // [num_tokens, head_num, head_dim]
     at::Tensor& value,               // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& key_cache,           // [num_blocks, head_num, head_dim/x, block_size, x]
     at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& sequence_lengths,    // [batch_size]
     at::Tensor& block_tables)        // [batch_size, max_seq_len]
@@ -77,7 +92,8 @@ void apply_decode_kv_cache_memcpy(
     int num_tokens = key.size(0);
     int head_num = key.size(1);
     int head_dim = key.size(2);
-    int block_size = key_cache.size(2);
+    int block_size = key_cache.size(3);
+    int x = key_cache.size(4);

     int64_t key_stride = key.stride(0);
     int64_t value_stride = value.stride(0);
@@ -110,7 +126,8 @@ void apply_decode_kv_cache_memcpy(
                 block_size, \
                 key_stride, \
                 value_stride, \
-                block_table_stride \
+                block_table_stride, \
+                x \
             ); \
     } while(0)
@@ -146,7 +163,7 @@ void apply_decode_kv_cache_memcpy(
 void decode_kv_cache_memcpy(
     at::Tensor& key,                 // [num_tokens, head_num, head_dim]
     at::Tensor& value,               // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& key_cache,           // [num_blocks, head_num, head_dim/x, block_size, x]
     at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& sequence_lengths,    // [batch_size]
     at::Tensor& block_tables)        // [batch_size, max_seq_len]

View File

@@ -67,6 +67,7 @@ __global__ void flash_decoding_attention_kernel(
   const cache_t* __restrict__ v_cache,      // [num_blocks, num_kv_heads, block_size, head_size]
   const int* __restrict__ context_lens,     // [num_tokens]
   const int* __restrict__ block_tables,     // [num_tokens, max_num_blocks_per_seq]
+  const float* __restrict__ alibi_slopes,   // [num_heads]
   const int max_seq_len,
   const int num_kv_heads,
   const float scale,
@@ -105,6 +106,7 @@ __global__ void flash_decoding_attention_kernel(
   using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;

   const int context_len = context_lens[seq_idx];
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
   const int thread_group_offset = lane % NUM_THREADS_PER_X;
   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -164,6 +166,7 @@ __global__ void flash_decoding_attention_kernel(
       if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
         const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
+        qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
         const bool mask = token_idx >= context_len;
         logits[token_idx] = mask ? 0.f : qk;
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -261,6 +264,7 @@ __global__ void flash_decoding_attention_kernel(
     reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
     context_lens.data_ptr<int>(), \
     block_tables.data_ptr<int>(), \
+    alibi_slopes_ptr, \
     max_context_len, \
     num_kv_heads, \
     scale, \
@@ -282,7 +286,8 @@ void flash_decoding_attention_v1_launcher(
   torch::Tensor& context_lens,    // [num_tokens]
   torch::Tensor& block_tables,    // [num_tokens, max_num_blocks_per_seq]
   int max_context_len,
-  float scale) {
+  float scale,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
   int num_tokens = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -304,6 +309,10 @@ void flash_decoding_attention_v1_launcher(
   // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);

+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
   dim3 grid(num_heads, num_tokens, 1);
   dim3 block(NUM_THREADS);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
@@ -336,7 +345,8 @@ void flash_decoding_attention_v1_launcher(
     context_lens, \
     block_tables, \
     max_context_len, \
-    scale);
+    scale, \
+    alibi_slopes);

 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
@@ -367,6 +377,7 @@ void flash_decoding_attention(
   int max_context_len,
   torch::Tensor& tmp_out,       // [num_tokens, num_heads, max_num_partitions, head_size]
   torch::Tensor& tmp_out_lse,   // [num_tokens, num_heads, max_num_partitions]
+  const c10::optional<torch::Tensor>& alibi_slopes,
   float scale) {
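The only arithmetic change inside the kernel is the single line that adds `alibi_slope * (token_idx - context_len + 1)` to each QK logit before masking and softmax. A toy PyTorch reference of ALiBi-biased decode attention for a single head and a single new token, shown only to illustrate that bias term (not the kernel's blocked algorithm):

    import torch

    def alibi_decode_attention_ref(q, k, v, alibi_slope, scale):
        # q: [head_size], k/v: [context_len, head_size]
        context_len = k.size(0)
        logits = scale * (k @ q)                                  # [context_len]
        # Bias is 0 for the newest token and increasingly negative for older ones,
        # matching `alibi_slope * (token_idx - context_len + 1)` in the kernel.
        positions = torch.arange(context_len, dtype=logits.dtype)
        logits = logits + alibi_slope * (positions - context_len + 1)
        probs = torch.softmax(logits, dim=-1)
        return probs @ v                                          # [head_size]

    head_size, context_len = 8, 5
    q = torch.randn(head_size)
    k = torch.randn(context_len, head_size)
    v = torch.randn(context_len, head_size)
    out = alibi_decode_attention_ref(q, k, v, alibi_slope=0.25, scale=head_size**-0.5)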

View File

@@ -91,7 +91,7 @@ __device__ void apply_k_rotary_emb_compute(
     const int* __restrict__ block_tables, const int64_t key_stride,
     const int64_t value_stride, const int token_id,
     const int block_table_stride, const int head_num, const int head_dim,
-    const int kv_head_num, const int block_size, const int half_head_dim,
+    const int kv_head_num, const int block_size, const int x, const int half_head_dim,
     const int shard_block_size) {
   const int seq_len = sequence_lengths[token_id] - 1;
   const int block_offset = seq_len % block_size;
@@ -102,36 +102,40 @@ __device__ void apply_k_rotary_emb_compute(
     return;
   }

-  scalar_t x[VecSize];
-  scalar_t y[VecSize];
+  scalar_t x0[VecSize];
+  scalar_t x1[VecSize];
   scalar_t out_x[VecSize];
   scalar_t out_y[VecSize];

   for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
        i += blockDim.x * VecSize) {
-    const int head_offset = i % half_head_dim;
+    const int half_head_offset = i % half_head_dim;
+    const int x_id = half_head_offset / x;
+    const int x_offset = half_head_offset % x;
     const int shard_offset =
-        (head_offset / shard_block_size) * shard_block_size +
-        (head_offset % shard_block_size) / VecSize;
+        (half_head_offset / shard_block_size) * shard_block_size +
+        (half_head_offset % shard_block_size) / VecSize;
     const int64_t addr_offset =
-        token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
-    const int64_t target_id = block_id * kv_head_num * head_dim * block_size +
-                              (i / half_head_dim) * block_size * head_dim +
-                              block_offset * head_dim + head_offset;
+        token_id * key_stride + (i / half_head_dim) * head_dim + half_head_offset;
+    const int64_t target_id = block_id * kv_head_num * head_dim * block_size
+                              + (i / half_head_dim) * block_size * head_dim
+                              + x_id * block_size * x
+                              + block_offset * x
+                              + x_offset;

-    copy_vector<scalar_t, VecSize>(x, key + addr_offset);
-    copy_vector<scalar_t, VecSize>(y, key + addr_offset + half_head_dim);
+    copy_vector<scalar_t, VecSize>(x0, key + addr_offset);
+    copy_vector<scalar_t, VecSize>(x1, key + addr_offset + half_head_dim);

 #pragma unroll
     for (int j = 0; j < VecSize; j++) {
-      out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
-                                       static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
-      out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
-                                       static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
+      out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x0[j]) * cos_ptr[j * 32 + shard_offset] -
+                                       static_cast<m_scalar_t>(x1[j]) * sin_ptr[j * 32 + shard_offset]);
+      out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x1[j]) * cos_ptr[j * 32 + shard_offset] +
+                                       static_cast<m_scalar_t>(x0[j]) * sin_ptr[j * 32 + shard_offset]);
     }

     copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
-    copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim,
+    copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim * block_size,
                                    out_y);
   }
@@ -162,7 +166,8 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
     const int head_num,
     const int head_dim,
     const int kv_head_num,
-    const int block_size
+    const int block_size,
+    const int x
 ) {
   const int token_id = blockIdx.x;
@@ -182,7 +187,7 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
   apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);

   //compute key and copy kv
-  apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
+  apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size);
 }
@@ -220,6 +225,31 @@ __global__ void rotary_embedding_kernel(
   apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
 }

+#define ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(VEC_SIZE) \
+  rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, VEC_SIZE><<<grid, block, shared_memory_size, stream>>>( \
+      query.data_ptr<scalar_t>(), \
+      key.data_ptr<scalar_t>(), \
+      value.data_ptr<scalar_t>(), \
+      cos.data_ptr<scalar_t>(), \
+      sin.data_ptr<scalar_t>(), \
+      key_cache.data_ptr<scalar_t>(), \
+      value_cache.data_ptr<scalar_t>(), \
+      sequence_lengths.data_ptr<int>(), \
+      block_tables.data_ptr<int>(), \
+      query_stride, \
+      key_stride, \
+      value_stride, \
+      shard_element_num / 2, \
+      cos_stride, \
+      sin_stride, \
+      block_table_stride, \
+      head_num, \
+      head_dim, \
+      kv_head_num, \
+      block_size, \
+      x); \
+
 template<typename scalar_t, bool high_precision>
 void apply_rotary_embedding_and_cache_copy(
     at::Tensor& query,               // [num_tokens, head_num, head_dim]
@@ -227,7 +257,7 @@ void apply_rotary_embedding_and_cache_copy(
     at::Tensor& value,               // [num_tokens, kv_head_num, head_dim]
     at::Tensor& cos,                 // [num_tokens, head_dim]
     at::Tensor& sin,                 // [num_tokens, head_dim]
-    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& key_cache,           // [num_blocks, head_num, head_dim/x, block_size, x]
    at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& sequence_lengths,    // [batch_size]
     at::Tensor& block_tables)        // [batch_size, max_seq_len]
@@ -236,7 +266,8 @@ void apply_rotary_embedding_and_cache_copy(
     int head_num = query.size(1);
     int head_dim = query.size(2);
     int kv_head_num = key.size(1);
-    int block_size = key_cache.size(2);
+    int block_size = key_cache.size(3);
+    int x = key_cache.size(4);

     int64_t query_stride = query.stride(0);
     int64_t key_stride = key.stride(0);
@@ -262,79 +293,17 @@ void apply_rotary_embedding_and_cache_copy(
     dim3 grid(num_tokens);
     dim3 block(std::min(thread_nums, 512));
     int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size;
+    const int shared_memory_size = shard_element_num * sizeof(m_scalar_t);

     switch (vec_size) {
         case 1:
-            rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
-                query.data_ptr<scalar_t>(),
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                cos.data_ptr<scalar_t>(),
-                sin.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                query_stride,
-                key_stride,
-                value_stride,
-                shard_element_num / 2,
-                cos_stride,
-                sin_stride,
-                block_table_stride,
-                head_num,
-                head_dim,
-                kv_head_num,
-                block_size
-            );
+            ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(1);
             break;
         case 2:
-            rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
-                query.data_ptr<scalar_t>(),
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                cos.data_ptr<scalar_t>(),
-                sin.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                query_stride,
-                key_stride,
-                value_stride,
-                shard_element_num / 2,
-                cos_stride,
-                sin_stride,
-                block_table_stride,
-                head_num,
-                head_dim,
-                kv_head_num,
-                block_size
-            );
+            ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(2);
             break;
         case 4:
-            rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
-                query.data_ptr<scalar_t>(),
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                cos.data_ptr<scalar_t>(),
-                sin.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                query_stride,
-                key_stride,
-                value_stride,
-                shard_element_num / 2,
-                cos_stride,
-                sin_stride,
-                block_table_stride,
-                head_num,
-                head_dim,
-                kv_head_num,
-                block_size
-            );
+            ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(4);
             break;
         default:
             AT_ERROR("Unsupported vectorized size ", vec_size);
@@ -441,7 +410,7 @@ void rotary_embedding_and_cache_copy(
     at::Tensor& value,              // [num_tokens, kv_head_num, head_dim]
     at::Tensor& cos,                // [num_tokens, head_dim]
     at::Tensor& sin,                // [num_tokens, head_dim]
-    at::Tensor& key_cache,          // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& key_cache,          // [num_blocks, head_num, head_dim/x, block_size, x]
     at::Tensor& value_cache,        // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& sequence_lengths,   // [batch_size]
     at::Tensor& block_tables,       // [batch_size, max_seq_len]
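One detail worth noting in the rotary kernel above: the rotated first half of each head is written at `target_id`, and the second half is now written at `target_id + half_head_dim * block_size`, because in the `[.., head_dim/x, block_size, x]` layout moving forward by `half_head_dim` channels for the same token means skipping `half_head_dim / x` chunks of `block_size * x` elements. A rough PyTorch reference of rotating one token's key and scattering it into the new-layout cache (illustrative only, non-interleaved rotary as used by the kernel):

    import torch

    def rotary_and_write_key_ref(key, cos, sin, key_cache, block_id, block_offset):
        # key:       [kv_head_num, head_dim]
        # cos/sin:   [head_dim // 2]
        # key_cache: [num_blocks, kv_head_num, head_dim // x, block_size, x]
        x = key_cache.size(-1)
        half = key.size(-1) // 2
        k0, k1 = key[..., :half], key[..., half:]
        rotated = torch.cat([k0 * cos - k1 * sin, k1 * cos + k0 * sin], dim=-1)
        # Channel d of this token lands at key_cache[block_id, head, d // x, block_offset, d % x].
        key_cache[block_id, :, :, block_offset, :] = rotated.view(rotated.size(0), -1, x)
        return key_cache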

View File

@@ -3,7 +3,8 @@
 void decode_kv_cache_memcpy(
     torch::Tensor& key,    // [num_tokens, num_heads, head_size]
     torch::Tensor& value,  // [num_tokens, num_heads, head_size]
-    torch::Tensor& key_cache,  // [num_blocks, num_heads, block_size, head_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, head_num, head_dim/x, block_size, x]
     torch::Tensor&
         value_cache,  // [num_blocks, num_heads, block_size, head_size]
     torch::Tensor& sequence_lengths,  // [batch_size]
@@ -12,7 +13,7 @@ void decode_kv_cache_memcpy(
 void context_kv_cache_memcpy(
     at::Tensor& key,          // [num_tokens, head_num, head_dim]
     at::Tensor& value,        // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,    // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& key_cache,    // [num_blocks, head_num, head_dim/x, block_size, x]
     at::Tensor& value_cache,  // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& sequence_lengths,  // [batch_size]
     at::Tensor& cu_seqlens,        // [batch_size + 1]
@@ -32,7 +33,8 @@ void rotary_embedding_and_cache_copy(
     torch::Tensor& value,  // [num_tokens, num_heads, head_dim]
     torch::Tensor& cos,    // [num_tokens, head_dim]
     torch::Tensor& sin,    // [num_tokens, head_dim]
-    torch::Tensor& key_cache,  // [num_blocks, num_heads, block_size, head_dim]
+    torch::Tensor&
+        key_cache,  // [num_blocks, head_num, head_dim/x, block_size, x]
     torch::Tensor&
         value_cache,  // [num_blocks, num_heads, block_size, head_dim]
     torch::Tensor& sequence_lengths,  // [batch_size]
@@ -71,7 +73,7 @@ void flash_decoding_attention(
     torch::Tensor&
         tmp_out,  // [num_tokens, num_heads, max_num_partitions, head_size]
     torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
-    float scale);
+    const c10::optional<torch::Tensor>& alibi_slopes, float scale);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
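From Python, this binding change means every caller of the CUDA flash-decoding op must now pass `alibi_slopes` (or `None`) before `sm_scale`. A thin wrapper sketch showing the argument order exercised by the model code and tests in this commit (the wrapper name itself is illustrative):

    from colossalai.kernel.kernel_loader import InferenceOpsLoader

    inference_ops = InferenceOpsLoader().load()

    def cuda_flash_decoding(output, query, k_cache, v_cache, kv_seq_lengths, block_tables,
                            block_size, max_seq_len, mid_output, mid_output_lse,
                            alibi_slopes, sm_scale):
        # alibi_slopes is the new parameter: a fp32 [num_heads] tensor for ALiBi
        # models (e.g. Baichuan), or None for models without ALiBi (e.g. Llama).
        inference_ops.flash_decoding_attention(
            output, query, k_cache, v_cache, kv_seq_lengths, block_tables,
            block_size, max_seq_len, mid_output, mid_output_lse,
            alibi_slopes, sm_scale,
        )
        return output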

View File

@@ -4,8 +4,10 @@ import numpy as np
 import pytest
 import torch

+from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.utils import get_current_device
+from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask

 inference_ops = InferenceOpsLoader().load()
@@ -60,8 +62,9 @@ def numpy_allclose(x, y, rtol, atol):
 @pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
 @pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+@pytest.mark.parametrize("use_alibi_slopes", [True, False])
 def test_flash_decoding_attention(
-    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
+    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@@ -73,6 +76,11 @@ def test_flash_decoding_attention(
     MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
     device = get_current_device()

+    if use_alibi_slopes:
+        alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
+    else:
+        alibi_slopes = None
+
     q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
         BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
     )
@@ -91,6 +99,15 @@ def test_flash_decoding_attention(
     v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
     torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)

+    if use_alibi_slopes:
+        alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
+        torch_padding_mask = torch_padding_mask + alibi_mask
+
+        if len(torch_padding_mask.size()) == 4:
+            torch_padding_mask = torch_padding_mask[:, :, -1:, :]
+        else:
+            torch_padding_mask = torch_padding_mask[:, -1:, :]
+
     mid_output = torch.empty(
         size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
     )
@@ -146,8 +163,14 @@ def test_flash_decoding_attention(
         max_seq_len_across_batch,
         mid_output,
         mid_output_lse,
+        alibi_slopes,
         sm_scale,
     )

+    # The alibi may introduce relatively large errors
+    if use_alibi_slopes:
+        rtol = 1e0
+
     numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
@@ -168,8 +191,9 @@ except ImportError:
 @pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
 @pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+@pytest.mark.parametrize("use_alibi_slopes", [True, False])
 def test_vllm_flash_decoding_attention(
-    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
+    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@@ -199,6 +223,18 @@ def test_vllm_flash_decoding_attention(
     v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
     torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)

+    if use_alibi_slopes:
+        alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
+        alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
+        torch_padding_mask = torch_padding_mask + alibi_mask
+
+        if len(torch_padding_mask.size()) == 4:
+            torch_padding_mask = torch_padding_mask[:, :, -1:, :]
+        else:
+            torch_padding_mask = torch_padding_mask[:, -1:, :]
+    else:
+        alibi_slopes = None
+
     if dtype == torch.float16:
         rtol = 1e-3
         atol = 1e-3
@@ -236,8 +272,6 @@ def test_vllm_flash_decoding_attention(
         HEAD_SIZE,
     )

-    alibi_slopes = None
-
     vllm_ops.paged_attention_v1(
         output,
         q.squeeze(2),
@@ -253,6 +287,11 @@ def test_vllm_flash_decoding_attention(
         "auto",
         kv_scale,
     )

+    # The alibi may introduce relatively large errors
+    if use_alibi_slopes:
+        rtol = 1e0
+
     numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
@@ -277,5 +316,5 @@ if __name__ == "__main__":
         dtype,
     ) in test_combinations:
         test_flash_decoding_attention(
-            batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype
+            batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True
         )

View File

@@ -4,12 +4,40 @@ import torch.nn.functional as F
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2
-from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
+from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3, mock_alloc_single_token

 inference_ops = InferenceOpsLoader().load()

-HEAD_DIM = 4
+HEAD_DIM = 72
+
+
+def prepare_data(
+    bsz,
+    num_kv_heads,
+    block_size,
+    max_num_blocks_per_seq,
+    context_lengths,
+    device="cuda",
+    dtype=torch.float16,
+):
+    num_tokens = torch.sum(context_lengths).item()
+
+    max_seq_len_in_batch = context_lengths.max()
+    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+
+    kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
+    key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+    value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+
+    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3(
+        key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+    )
+    block_tables = block_tables.to(device=device)
+    k_cache = torch.zeros_like(k_cache_ref)
+    v_cache = torch.zeros_like(v_cache_ref)
+
+    return key, value, k_cache, v_cache, cu_seqlens, block_tables, max_seq_len_in_batch, k_cache_ref, v_cache_ref
@@ -24,32 +52,41 @@ def run_decode_copy_kv_to_caches(
     torch.cuda.synchronize()
     torch.cuda.reset_peak_memory_stats()

+    n = 1
     max_seq_len = block_size * max_num_blocks_per_seq
     dtype = torch.float32
     device = get_current_device()
-    new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
-        bsz,
-        num_kv_heads,
-        HEAD_DIM,
-        block_size,
-        max_num_blocks_per_seq,
-        same_context_len,
-        max_seq_len,
-        device=device,
-        dtype=dtype,
-    )
-    new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
-    new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
-    inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
-    past_kv_seq_len = kv_seq_lengths - 1
+    assert max_seq_len > n, "max_seq_len must be greater than n"
+
+    past_kv_seq_lengths = (
+        torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
+        if same_context_len
+        else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
+    )
+
+    key, value, k_cache, v_cache, _, block_tables, _, _, _ = prepare_data(
+        bsz, num_kv_heads, block_size, max_num_blocks_per_seq, past_kv_seq_lengths, device, dtype
+    )
+
+    new_k = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
+    new_v = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
+
+    # mock allocating blocks for the new k/v and update block tables
+    for _ in range(n):
+        mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
+        past_kv_seq_lengths += 1
+
+    inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables)
+
+    past_kv_seq_len = past_kv_seq_lengths - 1
     target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
     offsets_in_block = past_kv_seq_len % block_size
-    k_target = k_cache[target_block_ids, :, offsets_in_block, :]
+    k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
     k_source = new_k.squeeze()
     v_target = v_cache[target_block_ids, :, offsets_in_block, :]
+    k_target = k_target.reshape(v_target.shape)
     v_source = new_v.squeeze()

     assert k_target.shape == k_source.shape
@@ -77,22 +114,17 @@ def run_context_copy_kv_to_cache(
     else:
         context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)

-    num_tokens = torch.sum(context_lengths).item()
-
-    max_seq_len_in_batch = context_lengths.max()
-    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
-
-    kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
-    key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-    value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-
-    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
-        key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
-    )
-    block_tables = block_tables.to(device=device)
-    k_cache = torch.zeros_like(k_cache_ref)
-    v_cache = torch.zeros_like(v_cache_ref)
+    (
+        key,
+        value,
+        k_cache,
+        v_cache,
+        cu_seqlens,
+        block_tables,
+        max_seq_len_in_batch,
+        k_cache_ref,
+        v_cache_ref,
+    ) = prepare_data(bsz, num_kv_heads, block_size, max_num_blocks_per_seq, context_lengths, device, dtype)

     inference_ops.context_kv_cache_memcpy(
         key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch
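The verification pattern used by these updated tests is worth spelling out: a token written at `block_offset` is gathered from the new-layout K cache with an extra `:` on the `head_dim // x` axis and reshaped back to `[kv_head_num, head_dim]` before comparison. A small readback sketch (tensor names are illustrative):

    import torch

    def read_back_last_token(k_cache, v_cache, block_tables, kv_seq_lengths, block_size):
        # k_cache: [num_blocks, kv_head_num, head_dim // x, block_size, x]
        # v_cache: [num_blocks, kv_head_num, block_size, head_dim]
        past_kv_seq_len = kv_seq_lengths - 1
        target_block_ids = block_tables[range(block_tables.size(0)), past_kv_seq_len // block_size]
        offsets_in_block = past_kv_seq_len % block_size
        k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
        v_target = v_cache[target_block_ids, :, offsets_in_block, :]
        # Collapse (head_dim // x, x) back to head_dim so K and V can be compared alike.
        k_target = k_target.reshape(v_target.shape)
        return k_target, v_target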

View File

@@ -7,7 +7,7 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader
 inference_ops = InferenceOpsLoader().load()

-from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
+from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
 from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
@@ -49,12 +49,14 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     cos_shape = (TOTAL_TOKENS, D // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
-    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    k_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, D // x, block_size, x)
+    v_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
+    k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device="cuda")
     v = torch.randn_like(k)
-    v_cache = torch.zeros_like(k_cache)
+    v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
     past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
-    block_tables = mock_alloc_block_table_and_kvcache_v2(
+    block_tables = mock_alloc_block_table_and_kvcache_v3(
         k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
     )
     new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda")
@@ -97,9 +99,10 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     past_kv_seq_len = kv_seq_lengths - 1
     target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
     offsets_in_block = past_kv_seq_len % block_size
-    k_target = k_cache[target_block_ids, :, offsets_in_block, :].squeeze()
+    k_target = k_cache[target_block_ids, :, :, offsets_in_block, :].squeeze()
     k_source = new_k_copy.squeeze()
     v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
+    k_target = k_target.reshape(v_target.shape)
     v_source = new_v.squeeze()

     numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)