Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 13:59:08 +00:00)
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel)
* revise shape of kvcache (flash decoding kernel)
* revise shape of kvcache (kvcache copy) and attn func
* init of kvcache in kvcache manager
* revise llama modeling
* revise block size retrieval
* use torch for rms_norm benchmarking
* revise block size retrieval
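The core of the change is the blocked KV cache layout: the last two cache dimensions swap from (head_dim, block_size) to (block_size, head_dim), which is what the reshaped tensors in the test diff below reflect. Below is a minimal sketch of a decode-step write into a cache with the new layout; the helper name `write_token_to_blocked_cache` and its argument conventions are illustrative assumptions, not the repository's `copy_to_cache` API.

```python
import torch

# Shapes mirror the tests below: 8 blocks, 3 heads, block size 8, head dim 3.
num_blocks, num_heads, block_size, head_dim = 8, 3, 8, 3

# Old layout (before this commit): (num_blocks, num_heads, head_dim, block_size)
old_cache = torch.zeros(num_blocks, num_heads, head_dim, block_size)
# New layout (after this commit): (num_blocks, num_heads, block_size, head_dim)
new_cache = torch.zeros(num_blocks, num_heads, block_size, head_dim)


def write_token_to_blocked_cache(cache, key, block_tables, lengths):
    """Hypothetical helper: scatter one new token's key per sequence into the cache.

    key:          (bsz, num_heads, head_dim), the token being appended
    block_tables: (bsz, max_blocks_per_seq), logical block -> physical block id
    lengths:      (bsz,), sequence length *after* appending this token
    """
    for seq_idx in range(key.size(0)):
        pos = lengths[seq_idx].item() - 1            # position of the new token
        block_id = block_tables[seq_idx, pos // block_size]
        offset = pos % block_size
        # New layout: the token's per-head vector lives in the contiguous last dim.
        cache[block_id, :, offset, :] = key[seq_idx]
    return cache


key = torch.randn(2, num_heads, head_dim)
block_tables = torch.tensor([[0, 1], [2, 3]])
lengths = torch.tensor([10, 9])
write_token_to_blocked_cache(new_cache, key, block_tables, lengths)
```

With the new ordering, each token's per-head `head_dim` vector is contiguous in memory, which is the main practical difference from the old (head_dim, block_size) ordering.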
@@ -1,20 +1,17 @@
-import pytest
 import torch
 from transformers.cache_utils import DynamicCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-import colossalai
 from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache
-from colossalai.testing import rerun_if_address_is_in_use, spawn


 def test_copy_to_cache():
     key = torch.ones((2, 11, 3, 3))
     key[0, 9, :, :] = 0
     key[1, -2:, :, :] = 0
-    cache = torch.zeros(8, 3, 3, 8)
+    cache = torch.zeros(8, 3, 8, 3)
     block_tables = torch.tensor([[0, 1], [2, 3]])
     lengths = torch.tensor([9, 8])
     cache = copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="prefill")
@@ -28,7 +25,7 @@ def test_copy_to_cache():


 def test_convert_kvcache():
-    cache = torch.ones(8, 3, 3, 8)
+    cache = torch.ones(8, 3, 8, 3)
     key = torch.ones(2, 1, 3, 3) + 1
     lengths = torch.tensor([10, 9])
     block_tables = torch.tensor([[0, 1], [2, 3]])
@@ -43,8 +40,8 @@ def test_context_attention():
     """
     attn = PagedAttention()
     q = k = v = torch.randn(8, 4, 4)
-    k_cache = torch.empty(8, 4, 4, 8)
-    v_cache = torch.empty(8, 4, 4, 8)
+    k_cache = torch.empty(8, 4, 8, 4)
+    v_cache = torch.empty(8, 4, 8, 4)
     context_lengths = torch.tensor(
         [
             8,
@@ -136,23 +133,8 @@ def test_decoding_attention():
     assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2)


-def check_attention_layer():
+if __name__ == "__main__":
     test_copy_to_cache()
     test_convert_kvcache()
     test_context_attention()
     test_decoding_attention()
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
-    check_attention_layer()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_attention_layer():
-    spawn(run_dist, 1)
-
-
-if __name__ == "__main__":
-    test_attention_layer()
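For reference, here is the read side under the same layout, roughly what `test_convert_kvcache` exercises: gathering each sequence's cached entries back into a padded (bsz, max_len, num_heads, head_dim) tensor via the block table. `gather_from_blocked_cache` is a hypothetical re-implementation for illustration, not the repository's `convert_kvcache`.

```python
import torch


def gather_from_blocked_cache(cache, block_tables, lengths):
    """Hypothetical helper: rebuild padded (bsz, max_len, num_heads, head_dim) keys
    from a cache laid out as (num_blocks, num_heads, block_size, head_dim)."""
    num_blocks, num_heads, block_size, head_dim = cache.shape
    bsz = block_tables.size(0)
    max_len = int(lengths.max())
    out = torch.zeros(bsz, max_len, num_heads, head_dim, dtype=cache.dtype)
    for seq_idx in range(bsz):
        for pos in range(int(lengths[seq_idx])):
            block_id = block_tables[seq_idx, pos // block_size]
            offset = pos % block_size
            # (num_heads, head_dim) slice for this token
            out[seq_idx, pos] = cache[block_id, :, offset, :]
    return out


# Shapes taken from the test above: 8 blocks, 3 heads, block size 8, head dim 3.
cache = torch.ones(8, 3, 8, 3)
block_tables = torch.tensor([[0, 1], [2, 3]])
lengths = torch.tensor([10, 9])
padded = gather_from_blocked_cache(cache, block_tables, lengths)
print(padded.shape)  # torch.Size([2, 10, 3, 3])
```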