[Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301)
* fix decoding kernel pytest
* revise and add triton context attn benchmark
@@ -87,7 +87,7 @@ class PagedAttention:
         Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
         """
         bsz = len(seq_lengths)
-        padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size)
+        padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size, dtype=tensor.dtype)

         token_idx = 0
         for i, seq_len in enumerate(seq_lengths):
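For context, the following is a minimal sketch of the padding transform this hunk belongs to, assuming the method receives the packed tensor along with seq_lengths, max_seq_len, num_heads, and head_size; the function name and signature are illustrative, not the repository's actual API. It shows why the fix allocates the output with dtype=tensor.dtype: otherwise fp16/bf16 inputs would be copied into a default float32 buffer.

import torch

def pad_no_pad_tensor(tensor: torch.Tensor,
                      seq_lengths: list[int],
                      max_seq_len: int,
                      num_heads: int,
                      head_size: int) -> torch.Tensor:
    # tensor: [total_tokens, num_heads, head_size], all sequences concatenated with no padding
    bsz = len(seq_lengths)
    # Allocate the padded output with the input's dtype (and device) to avoid an implicit upcast
    padded_tensor = torch.zeros(
        bsz, max_seq_len, num_heads, head_size, dtype=tensor.dtype, device=tensor.device
    )
    token_idx = 0
    for i, seq_len in enumerate(seq_lengths):
        # Copy this sequence's tokens into the left-aligned slot of batch row i
        padded_tensor[i, :seq_len] = tensor[token_idx : token_idx + seq_len]
        token_idx += seq_len
    return padded_tensor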