Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[kernel/fix] Performance Optimization for Decoding Kernel and Benchmarking (#5274)
* prevent re-creating intermediate tensors
* add singleton class holding intermediate values
* fix triton kernel api
* add benchmark in pytest
* fix kernel api and add benchmark
* revise flash decoding triton kernel in/out shapes
* fix calling of triton kernel in modeling
* fix pytest: extract to util functions
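The "singleton class holding intermediate values" bullet refers to caching the scratch buffers that the split-KV flash-decoding kernel writes its partial results into, so they are allocated once and reused across decoding steps instead of being re-created on every forward pass. Below is a minimal sketch of such a cache; the class name, buffer keys, and shapes are illustrative assumptions, not the actual ColossalAI implementation.

```python
import torch


class DecodingIntermTensors:
    """Illustrative singleton cache for flash-decoding intermediate buffers.

    NOTE: a sketch only -- the field names and shapes are assumptions, not the
    buffers ColossalAI actually allocates in this commit.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._buffers = {}
        return cls._instance

    def get(self, key: str, shape, dtype=torch.float16, device="cuda"):
        """Return the cached tensor for `key`, reallocating only if shape/dtype changed."""
        buf = self._buffers.get(key)
        if buf is None or buf.shape != torch.Size(shape) or buf.dtype != dtype:
            buf = torch.empty(shape, dtype=dtype, device=device)
            self._buffers[key] = buf
        return buf


# Example reuse across decoding steps: repeated calls with the same key and shape
# return the same storage instead of allocating new tensors.
# cache = DecodingIntermTensors()
# mid_output = cache.get("mid_output", (batch_size, num_heads, kv_splits, head_dim))
# mid_output_lse = cache.get("mid_output_lse", (batch_size, num_heads, kv_splits))
```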
@@ -6,7 +6,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
 
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
-from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd
+from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention
 from colossalai.logging import get_dist_logger
 
 from flash_attn.bert_padding import index_first_axis, pad_input  # noqa
@@ -209,7 +209,15 @@ def llama_attn_forward(
     if HAS_TRITON:
         copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
         copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-        attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
+        # TODO Add dummy transpose and squeeze on in/out tensors of the triton flash decoding kernel
+        # in order to maintain compatibility. This part as well as the logics of handling query_states and attn_output
+        # should be revised, as we could see in previous part of `llama_attn_forward` we still have
+        # redundant transpose and the in/out of torch- and triton-version decoding kernel are not consistent.
+        query_states = query_states.transpose(1, 2)
+        attn_output = flash_decoding_attention(
+            query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+        )
+        attn_output = attn_output.squeeze(1)
     else:
         attn_output = PagedAttention.pad_decoding_forward(
             query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
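The commit also adds a pytest benchmark for the decoding kernel and extracts the input construction into util functions. Below is a minimal, self-contained sketch of how such a micro-benchmark could be timed with CUDA events; the measured workload here is plain PyTorch scaled-dot-product attention as a stand-in, whereas the actual test would compare the torch PagedAttention path against the triton flash_decoding_attention kernel on blocked KV caches.

```python
import pytest
import torch
import torch.nn.functional as F


def benchmark_fn(fn, warmup: int = 10, iters: int = 100) -> float:
    """Return the average latency of `fn` in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


@pytest.mark.skipif(not torch.cuda.is_available(), reason="decoding benchmark requires CUDA")
@pytest.mark.parametrize("bsz", [4, 16])
@pytest.mark.parametrize("kv_len", [256, 1024])
def test_decoding_attention_benchmark(bsz: int, kv_len: int):
    # Hypothetical sizes chosen only for illustration.
    num_heads, head_dim = 32, 128
    device, dtype = "cuda", torch.float16

    # Single-token (decoding) query attending to a growing KV history.
    q = torch.randn(bsz, num_heads, 1, head_dim, device=device, dtype=dtype)
    k = torch.randn(bsz, num_heads, kv_len, head_dim, device=device, dtype=dtype)
    v = torch.randn_like(k)

    # Stand-in workload: in the real benchmark this lambda would call the torch
    # PagedAttention decoding path, and a second lambda would call flash_decoding_attention.
    latency_ms = benchmark_fn(lambda: F.scaled_dot_product_attention(q, k, v))
    print(f"bsz={bsz} kv_len={kv_len}: {latency_ms:.3f} ms per decoding step")
```

Sharing one input-construction helper between the torch and triton paths (as the "extract to util functions" bullet suggests) keeps the two measurements comparable, since both kernels then see identical caches, block tables, and sequence lengths.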