[Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301)

* fix decoding kernel pytest

* revise and add triton context attn benchmark
Yuanheng Zhao
2024-01-23 17:16:02 +08:00
committed by GitHub
parent 8e606ecc7e
commit 3da9993b0d
5 changed files with 116 additions and 15 deletions


@@ -87,7 +87,7 @@ class PagedAttention:
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
"""
bsz = len(seq_lengths)
-        padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size)
+        padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size, dtype=tensor.dtype)
token_idx = 0
for i, seq_len in enumerate(seq_lengths):
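The change above matters because `torch.zeros` defaults to `float32`, so without `dtype=tensor.dtype` the padded output silently upcasts fp16/bf16 inputs. A minimal sketch of the padding transform, with hypothetical function and variable names (the surrounding class logic is not shown in the diff):

```python
import torch

def pad_sequences(tensor, seq_lengths, max_seq_len, num_heads, head_size):
    """Hypothetical standalone version of the padding helper: turn a 1D
    unpadded token tensor [total_tokens, num_heads, head_size] into a
    padded [bsz, max_seq_len, num_heads, head_size] tensor."""
    bsz = len(seq_lengths)
    # Inherit dtype from the input; torch.zeros would otherwise default
    # to float32 and upcast half-precision activations.
    padded_tensor = torch.zeros(
        bsz, max_seq_len, num_heads, head_size, dtype=tensor.dtype
    )
    token_idx = 0
    for i, seq_len in enumerate(seq_lengths):
        padded_tensor[i, :seq_len] = tensor[token_idx : token_idx + seq_len]
        token_idx += seq_len
    return padded_tensor

# Two sequences of lengths 2 and 3, packed into 5 tokens total.
x = torch.randn(5, 2, 4, dtype=torch.float16)
out = pad_sequences(x, seq_lengths=[2, 3], max_seq_len=4, num_heads=2, head_size=4)
print(out.dtype, tuple(out.shape))
```

With the fix, `out.dtype` matches the input (`torch.float16` here) instead of defaulting to `torch.float32`.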