[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)

* revise shape of kvcache (context attn kernel)

* revise shape of kvcache (flash decoding kernel)

* revise shape of kvcache (kvcache copy) and attn func

* init of kvcache in kvcache manager

* revise llama modeling

* revise block size retrieval

* use torch for rms_norm benchmarking

* revise block size retrieval
This commit is contained in:
Yuanheng Zhao
2024-01-30 16:06:09 +08:00
committed by GitHub
parent e8f0642f28
commit 5f98a9d68a
14 changed files with 171 additions and 145 deletions

View File

@@ -93,7 +93,7 @@ def check_cache_manager(test_config):
assert len(cache_manager._cache_blocks) == num_blocks
key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers
assert len(key_caches) == num_layers
expected_kv_shape = (num_blocks, num_attention_heads, head_size, block_size)
expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size)
assert key_caches[0].shape == expected_kv_shape
k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
expected_kv_block_shape = expected_kv_shape[1:]

View File

@@ -1,20 +1,17 @@
import pytest
import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
import colossalai
from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
from colossalai.testing import rerun_if_address_is_in_use, spawn
def test_copy_to_cache():
key = torch.ones((2, 11, 3, 3))
key[0, 9, :, :] = 0
key[1, -2:, :, :] = 0
cache = torch.zeros(8, 3, 3, 8)
cache = torch.zeros(8, 3, 8, 3)
block_tables = torch.tensor([[0, 1], [2, 3]])
lengths = torch.tensor([9, 8])
cache = copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="prefill")
@@ -28,7 +25,7 @@ def test_copy_to_cache():
def test_convert_kvcache():
cache = torch.ones(8, 3, 3, 8)
cache = torch.ones(8, 3, 8, 3)
key = torch.ones(2, 1, 3, 3) + 1
lengths = torch.tensor([10, 9])
block_tables = torch.tensor([[0, 1], [2, 3]])
@@ -43,8 +40,8 @@ def test_context_attention():
"""
attn = PagedAttention()
q = k = v = torch.randn(8, 4, 4)
k_cache = torch.empty(8, 4, 4, 8)
v_cache = torch.empty(8, 4, 4, 8)
k_cache = torch.empty(8, 4, 8, 4)
v_cache = torch.empty(8, 4, 8, 4)
context_lengths = torch.tensor(
[
8,
@@ -136,23 +133,8 @@ def test_decoding_attention():
assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)
def check_attention_layer():
if __name__ == "__main__":
test_copy_to_cache()
test_convert_kvcache()
test_context_attention()
test_decoding_attention()
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_attention_layer()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_attention_layer():
spawn(run_dist, 1)
if __name__ == "__main__":
test_attention_layer()