[kernel/fix] Performance Optimization for Decoding Kernel and Benchmarking (#5274)

* prevent re-creating intermediate tensors

* add singleton class holding intermediate values

* fix triton kernel api

* add benchmark in pytest

* fix kernel api and add benchmark

* revise flash decoding triton kernel in/out shapes

* fix calling of triton kernel in modeling

* fix pytest: extract to util functions
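
The "singleton class holding intermediate values" mentioned in the commit messages above is not part of the excerpt below. As a rough illustration of the idea (pre-allocating the flash-decoding intermediates once and reusing them across decoding steps instead of re-creating them on every forward pass), a minimal sketch could look like the following; the class name, field names, and shapes are assumptions for illustration, not the actual implementation:

```python
import torch


class DecodingIntermTensors:
    """Hypothetical singleton holding the intermediate tensors of the flash-decoding kernel
    (partial outputs and their log-sum-exp per KV split), allocated once and reused across
    decoding steps instead of being re-created on every forward pass."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def initialize(self, max_batch_size, num_heads, max_kv_splits, head_dim, dtype, device):
        # Allocate only on the first call; later calls reuse the same buffers.
        if self._initialized:
            return
        self.mid_output = torch.empty(
            (max_batch_size, num_heads, max_kv_splits, head_dim), dtype=dtype, device=device
        )
        self.mid_output_lse = torch.empty(
            (max_batch_size, num_heads, max_kv_splits), dtype=torch.float32, device=device
        )
        self._initialized = True
```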
Author: Yuanheng Zhao
Date: 2024-01-19 15:47:16 +08:00
Committed by: GitHub
Parent: 9e2342bde2
Commit: 6e487e7d3c
7 changed files with 382 additions and 152 deletions
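
The commit messages also mention a benchmark added in pytest; the benchmark itself is not shown in the excerpt below. A minimal sketch of how such a kernel micro-benchmark is typically written (warmup plus CUDA-event timing) is given here. The tensor shapes and the assumed blocked-cache layout are illustrative only and may not match what `flash_decoding_attention` actually expects:

```python
import pytest
import torch

from colossalai.kernel.triton import flash_decoding_attention


def benchmark_cuda(fn, warmup: int = 10, iters: int = 100) -> float:
    """Return the average latency of `fn` in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a CUDA device")
def test_flash_decoding_benchmark():
    # All shapes below are illustrative; the cache layout assumed here is
    # [num_blocks, num_heads, block_size, head_dim] and may differ from the
    # layout the actual kernel expects.
    bsz, num_heads, head_dim, block_size = 16, 32, 128, 16
    max_blocks_per_seq = 8
    device, dtype = "cuda", torch.float16

    q = torch.randn(bsz, 1, num_heads, head_dim, device=device, dtype=dtype)
    k_cache = torch.randn(bsz * max_blocks_per_seq, num_heads, block_size, head_dim, device=device, dtype=dtype)
    v_cache = torch.randn_like(k_cache)
    kv_lengths = torch.randint(1, block_size * max_blocks_per_seq, (bsz,), device=device, dtype=torch.int32)
    block_tables = torch.arange(bsz * max_blocks_per_seq, device=device, dtype=torch.int32).view(bsz, -1)

    latency_ms = benchmark_cuda(
        lambda: flash_decoding_attention(q, k_cache, v_cache, kv_lengths, block_tables, block_size)
    )
    print(f"flash_decoding_attention: {latency_ms:.3f} ms / step")
```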


@@ -6,7 +6,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
-from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd
+from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention
 from colossalai.logging import get_dist_logger
 from flash_attn.bert_padding import index_first_axis, pad_input # noqa
@@ -209,7 +209,15 @@ def llama_attn_forward(
     if HAS_TRITON:
         copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
         copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-        attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
+        # TODO: Dummy transpose and squeeze are applied to the in/out tensors of the Triton flash-decoding
+        # kernel to maintain compatibility. This part, together with the handling of query_states and
+        # attn_output, should be revised: the earlier part of `llama_attn_forward` still performs a
+        # redundant transpose, and the in/out shapes of the torch and Triton decoding kernels are not consistent.
+        query_states = query_states.transpose(1, 2)
+        attn_output = flash_decoding_attention(
+            query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+        )
+        attn_output = attn_output.squeeze(1)
     else:
         attn_output = PagedAttention.pad_decoding_forward(
             query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
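
The two `copy_kv_to_blocked_cache` calls in the hunk above write the key/value of the token currently being decoded into the paged ("blocked") KV cache, at the slot determined by each sequence's current length and its block table. A rough pure-PyTorch equivalent of that bookkeeping is sketched below; the tensor layouts noted in the docstring are assumptions, and the real Triton kernel does not take a `block_size` argument but presumably derives it from the cache shape:

```python
import torch


def copy_kv_to_blocked_cache_ref(
    k: torch.Tensor, k_cache: torch.Tensor, kv_lengths: torch.Tensor,
    block_tables: torch.Tensor, block_size: int
) -> None:
    """Reference-style sketch of writing each sequence's newest key into its paged cache.

    Assumed (illustrative) layouts:
      k            [bsz, num_heads, head_dim]             key of the token being decoded
      k_cache      [num_blocks, num_heads, block_size, head_dim]
      kv_lengths   [bsz]                                  sequence lengths including the new token
      block_tables [bsz, max_blocks_per_seq]              logical-block -> physical-block ids
    """
    for seq_idx in range(k.size(0)):
        pos = int(kv_lengths[seq_idx]) - 1                         # slot index of the newest token
        block_id = int(block_tables[seq_idx, pos // block_size])   # physical block holding that slot
        k_cache[block_id, :, pos % block_size, :] = k[seq_idx]     # copy [num_heads, head_dim]
```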