mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 13:11:05 +00:00
[kernel] Add flash decoding triton kernel for blocked kv cache (#5249)
* add flash decoding unpad triton kernel * rename flash decoding kernel * add kernel testing (draft) * revise pytest * support kv group (GQA) * (trivial) fix api and pytest * (trivial) func renaming * (trivial) func/file renaming * refactor pytest for attention * (trivial) format and consistent vars of context/decode attn * (trivial) remove test redundancy
This commit is contained in:
committed by
FrankLeeeee
parent
fded91d049
commit
1513f20f4d
@@ -1,27 +1,102 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
|
||||
# This function is adapted from src/transformers/models/llama/modeling_llama.py
|
||||
# in huggingface transformers repository
|
||||
# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
||||
The hidden states go from (bsz, num_key_value_heads, seq_len, head_dim) to (bsz, num_attention_heads, seq_len, head_dim)
|
||||
"""
|
||||
xq = xq.view(bs, seqlen, num_head, head_dim)
|
||||
xk = xk.view(bs, seqlen, num_head, head_dim)
|
||||
xv = xv.view(bs, seqlen, num_head, head_dim)
|
||||
mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
|
||||
mask[mask == 0.0] = -100000000.0
|
||||
mask = mask.repeat(bs, num_head, 1, 1)
|
||||
keys = xk
|
||||
values = xv
|
||||
xq = xq.transpose(1, 2)
|
||||
keys = keys.transpose(1, 2)
|
||||
values = values.transpose(1, 2)
|
||||
sm_scale = 1 / math.sqrt(head_dim)
|
||||
scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
|
||||
scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
bsz, num_key_value_heads, seq_len, head_dim = hidden_states.shape
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim)
|
||||
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
|
||||
|
||||
output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
|
||||
return output
|
||||
|
||||
# Attention calculation adapted from HuggingFace transformers repository
|
||||
# src/transformers/models/llama/modeling_llama.py
|
||||
# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
|
||||
def torch_attn_ref(
|
||||
q: torch.Tensor, # [bsz, seq_len, num_heads, head_dim]
|
||||
k: torch.Tensor, # [bsz, kv_seq_len, num_heads, head_dim]
|
||||
v: torch.Tensor, # [bsz, kv_seq_len, num_heads, head_dim]
|
||||
attention_mask: torch.Tensor, # [bsz, 1, seq_len, kv_seq_len]
|
||||
bsz: int,
|
||||
seq_len: int,
|
||||
kv_seq_len: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
):
|
||||
assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim
|
||||
q = q.view(bsz, seq_len, num_heads, head_dim)
|
||||
k = k.view(bsz, kv_seq_len, num_kv_heads, head_dim)
|
||||
v = v.view(bsz, kv_seq_len, num_kv_heads, head_dim)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# repeat kv for GQA and MQA
|
||||
# k/v won't change if kv_group_num is 1
|
||||
assert num_heads % num_kv_heads == 0, "Number of heads is not multiple of kv heads"
|
||||
kv_group_num = num_heads // num_kv_heads
|
||||
k = repeat_kv(k, kv_group_num)
|
||||
v = repeat_kv(v, kv_group_num)
|
||||
|
||||
qk = torch.matmul(q, k.transpose(2, 3))
|
||||
attn_scores = qk / (head_dim**0.5)
|
||||
|
||||
assert attn_scores.shape == (bsz, num_heads, seq_len, kv_seq_len), "Invalid shape of attention scores"
|
||||
# for left-side padding
|
||||
if attention_mask.size() != (bsz, 1, seq_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, seq_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
|
||||
attn_scores = attn_scores + attention_mask
|
||||
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
|
||||
out = torch.matmul(attn_weights, v)
|
||||
if out.size() != (bsz, num_heads, seq_len, head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
|
||||
)
|
||||
out = out.transpose(1, 2).contiguous()
|
||||
return out
|
||||
|
||||
|
||||
def mock_alloc_block_table_and_kvcache(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
context_lengths: torch.Tensor,
|
||||
num_seqs: int,
|
||||
max_num_blocks_per_seq: int,
|
||||
block_size: int,
|
||||
):
|
||||
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
|
||||
block_id = 0
|
||||
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
|
||||
num_tokens_processed = 0
|
||||
for i, seq_len in enumerate(context_lengths.tolist()):
|
||||
right_bound = (seq_len + block_size - 1) // block_size # open bound
|
||||
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
|
||||
# Manually fill kv caches by copying from k and v
|
||||
for i in range(right_bound):
|
||||
if i == right_bound - 1:
|
||||
allocated_locs = seq_len % block_size or block_size
|
||||
else:
|
||||
allocated_locs = block_size
|
||||
k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
|
||||
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
|
||||
k_cache[block_id, :, :, :allocated_locs] = k_block
|
||||
v_cache[block_id, :, :, :allocated_locs] = v_block
|
||||
|
||||
num_tokens_processed += allocated_locs
|
||||
block_id += 1
|
||||
|
||||
return block_tables
|
||||
|
@@ -1,10 +1,10 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from packaging import version
|
||||
|
||||
from colossalai.kernel.triton import context_attention_unpadded
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, torch_attn_ref
|
||||
|
||||
try:
|
||||
import triton # noqa
|
||||
@@ -17,60 +17,40 @@ except ImportError:
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
|
||||
def torch_attn_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_len: int, num_heads: int, head_size: int):
|
||||
# For a single sequence, q,k,v [seq_len, num_heads, head_size]
|
||||
assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_size
|
||||
q = q.view(seq_len, num_heads, head_size)
|
||||
k = k.view(seq_len, num_heads, head_size)
|
||||
v = v.view(seq_len, num_heads, head_size)
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
|
||||
mask = torch.tril(torch.ones(1, seq_len, seq_len), diagonal=0).to(device=get_current_device())
|
||||
mask[mask == 0.0] = float("-inf")
|
||||
mask = mask.repeat(num_heads, 1, 1)
|
||||
|
||||
qk = torch.matmul(q, k.transpose(1, 2))
|
||||
attn_scores = qk / (head_size**0.5)
|
||||
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32) + mask, dim=-1).to(dtype=q.dtype)
|
||||
out = torch.matmul(attn_weights, v).transpose(0, 1).contiguous()
|
||||
out = out.reshape(-1, num_heads, head_size)
|
||||
return out
|
||||
|
||||
|
||||
def torch_attn_unpad(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor):
|
||||
# Process sequence one by one and cat them together.
|
||||
# q,k,v [num_tokens(sum(context_lengths)), num_heads, head_size]
|
||||
def torch_attn_unpad(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor, num_heads: int, num_kv_heads: int
|
||||
):
|
||||
# Process sequence one by one and concatenate them together.
|
||||
# q,k,v [num_tokens(sum(context_lengths)), num_heads, head_dim]
|
||||
assert context_lengths.dim() == 1, "context_lengths should be a 1D tensor"
|
||||
_, num_heads, head_size = q.shape
|
||||
|
||||
_, num_heads, head_dim = q.shape
|
||||
out_torch = []
|
||||
start_idx = 0
|
||||
for i in range(len(context_lengths)):
|
||||
end_idx = start_idx + context_lengths[i].item()
|
||||
for seq_i in range(len(context_lengths)):
|
||||
end_idx = start_idx + context_lengths[seq_i].item()
|
||||
seq_len = end_idx - start_idx
|
||||
mask = torch.tril(torch.ones(1, 1, seq_len, seq_len), diagonal=0).to(device=q.device)
|
||||
mask[mask == 0.0] = float("-inf")
|
||||
|
||||
torch_attn_ref_out = torch_attn_ref(
|
||||
q[start_idx:end_idx], k[start_idx:end_idx], v[start_idx:end_idx], end_idx - start_idx, num_heads, head_size
|
||||
q[start_idx:end_idx].unsqueeze(0),
|
||||
k[start_idx:end_idx].unsqueeze(0),
|
||||
v[start_idx:end_idx].unsqueeze(0),
|
||||
mask,
|
||||
1, # set bsz as 1 as we're processing sequence one by one
|
||||
seq_len,
|
||||
seq_len,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
)
|
||||
out_torch.append(torch_attn_ref_out)
|
||||
out_torch.append(torch_attn_ref_out.squeeze(0))
|
||||
start_idx = end_idx
|
||||
|
||||
return torch.cat(out_torch, dim=0)
|
||||
|
||||
|
||||
# This method is adapted from src/transformers/models/llama/modeling_llama.py
|
||||
# in transformers repository https://github.com/huggingface/transformers
|
||||
# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (num_tokens,
|
||||
num_key_value_heads, head_dim) to (num_tokens, num_attention_heads, head_dim)
|
||||
"""
|
||||
num_tokens, num_key_value_heads, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :].expand(num_tokens, num_key_value_heads, n_rep, head_dim)
|
||||
return hidden_states.reshape(num_tokens, num_key_value_heads * n_rep, head_dim)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
@pytest.mark.parametrize("bsz", [4, 7, 32])
|
||||
@pytest.mark.parametrize("block_size", [16, 32, 64])
|
||||
@@ -87,72 +67,46 @@ def test_context_attention(
|
||||
same_context_len: bool,
|
||||
):
|
||||
torch.manual_seed(123)
|
||||
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
num_seqs = bsz
|
||||
num_kv_heads = num_attn_heads // kv_group_num
|
||||
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
||||
head_size = 32
|
||||
max_seq_len = max_num_blocks_per_seq * block_size
|
||||
|
||||
# It's necessary to clear cache here.
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
num_kv_heads = num_attn_heads // kv_group_num
|
||||
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
||||
head_dim = 32
|
||||
max_seq_len = max_num_blocks_per_seq * block_size
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
if same_context_len:
|
||||
context_lengths = torch.tensor([max_seq_len for _ in range(num_seqs)], dtype=torch.int32, device=device)
|
||||
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
else:
|
||||
context_lengths = torch.randint(low=1, high=max_seq_len, size=(num_seqs,), dtype=torch.int32, device=device)
|
||||
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
|
||||
qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, head_size)
|
||||
qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, head_dim)
|
||||
qkv = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
q, k, v = torch.split(qkv, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
||||
|
||||
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_size, block_size)
|
||||
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
|
||||
k_cache_torch = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
k_cache_triton = torch.zeros_like(k_cache_torch)
|
||||
v_cache_torch = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
v_cache_triton = torch.zeros_like(v_cache_torch)
|
||||
|
||||
# Mock allocation on block tables
|
||||
block_id = 0
|
||||
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
|
||||
num_tokens_processed = 0
|
||||
for i, seq_len in enumerate(context_lengths.tolist()):
|
||||
right_bound = (seq_len + block_size - 1) // block_size # open bound
|
||||
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
|
||||
# Manually fill k_cache_torch and v_cache_torch by copying from k and v
|
||||
for i in range(right_bound):
|
||||
if i == right_bound - 1:
|
||||
allocated_locs = seq_len % block_size or block_size
|
||||
else:
|
||||
allocated_locs = block_size
|
||||
k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
|
||||
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
|
||||
cur_block_size_occupied = k_block.shape[-1]
|
||||
assert cur_block_size_occupied <= block_size, "Invalid occupied size of block during mock allocation"
|
||||
k_cache_torch[block_id, :, :, :cur_block_size_occupied] = k_block
|
||||
v_cache_torch[block_id, :, :, :cur_block_size_occupied] = v_block
|
||||
|
||||
num_tokens_processed += allocated_locs
|
||||
block_id += 1
|
||||
|
||||
block_tables = mock_alloc_block_table_and_kvcache(
|
||||
k, v, k_cache_torch, v_cache_torch, context_lengths, bsz, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
out_triton = context_attention_unpadded(
|
||||
q, k, v, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
|
||||
)
|
||||
|
||||
# For GQA and MQA, repeat k, v for torch attention calculation
|
||||
# k/v won't change if provided `num_kv_group` is 1
|
||||
num_kv_group = num_attn_heads // num_kv_heads
|
||||
k = repeat_kv(k, num_kv_group)
|
||||
v = repeat_kv(v, num_kv_group)
|
||||
out_torch = torch_attn_unpad(q, k, v, context_lengths)
|
||||
out_torch = torch_attn_unpad(q, k, v, context_lengths, num_attn_heads, num_kv_heads)
|
||||
|
||||
assert out_torch.shape == out_triton.shape
|
||||
assert torch.allclose(out_torch, out_triton, atol=1e-2, rtol=1e-3)
|
||||
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
||||
assert torch.allclose(k_cache_torch, k_cache_triton)
|
||||
assert torch.allclose(v_cache_torch, v_cache_triton)
|
||||
|
115
tests/test_infer_ops/triton/test_decoding_attn.py
Normal file
115
tests/test_infer_ops/triton/test_decoding_attn.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from colossalai.kernel.triton import flash_decoding_fwd
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, torch_attn_ref
|
||||
|
||||
try:
|
||||
import triton # noqa
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
|
||||
def torch_decoding(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor):
|
||||
assert context_lengths.dim() == 1, "context_lengths should be a 1D tensor"
|
||||
assert q.size(1) == 1, "Only used for decoding"
|
||||
assert k.shape == v.shape
|
||||
|
||||
bsz, _, num_heads, head_dim = q.shape
|
||||
_, kv_seq_len, num_kv_heads, _ = k.shape
|
||||
assert num_heads % num_kv_heads == 0, "Invalid kv heads and attention heads."
|
||||
padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=q.device)
|
||||
for i in range(bsz):
|
||||
cur_seq_len = context_lengths[i].item()
|
||||
assert cur_seq_len <= kv_seq_len
|
||||
padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf")
|
||||
|
||||
out = torch_attn_ref(q, k, v, padding_mask, bsz, 1, kv_seq_len, num_heads, num_kv_heads, head_dim)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
@pytest.mark.parametrize("bsz", [4, 7, 32])
|
||||
@pytest.mark.parametrize("block_size", [16, 32, 64])
|
||||
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
|
||||
@pytest.mark.parametrize("num_attn_heads", [16])
|
||||
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
|
||||
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||
def test_flash_decoding(
|
||||
bsz: int,
|
||||
block_size: int,
|
||||
max_num_blocks_per_seq: int,
|
||||
num_attn_heads: int,
|
||||
kv_group_num: int,
|
||||
same_context_len: bool,
|
||||
):
|
||||
torch.manual_seed(123)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
num_kv_heads = num_attn_heads // kv_group_num
|
||||
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
||||
q_len = 1
|
||||
head_dim = 128
|
||||
max_seq_len = block_size * max_num_blocks_per_seq
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
if same_context_len:
|
||||
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
else:
|
||||
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
|
||||
q_size = (bsz, q_len, num_attn_heads, head_dim)
|
||||
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
|
||||
kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)
|
||||
|
||||
cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
|
||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
# Mock allocation on block tables as well as blocked kv caches
|
||||
block_tables = mock_alloc_block_table_and_kvcache(
|
||||
k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
|
||||
q = q.view(bsz, q_len, num_attn_heads, head_dim)
|
||||
out_triton = flash_decoding_fwd(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
context_lengths,
|
||||
block_tables,
|
||||
block_size,
|
||||
kv_group_num,
|
||||
)
|
||||
out_triton = out_triton.unsqueeze(1) # [bsz, 1, num_heads, head_dim]
|
||||
|
||||
# rebuild (batched) kv with padding for torch attention
|
||||
# q [bsz, 1, num_heads, head_dim]
|
||||
# k/v [num_tokens, num_kv_heads, head_dim]
|
||||
max_seq_len = context_lengths.max().item()
|
||||
k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device)
|
||||
v_torch = torch.zeros_like(k_torch)
|
||||
prev_len_sum = 0
|
||||
for i, seq_len in enumerate(context_lengths.tolist()):
|
||||
# mock left-side padding
|
||||
k_torch[i, -seq_len:, :, :] = k[prev_len_sum : prev_len_sum + seq_len]
|
||||
v_torch[i, -seq_len:, :, :] = v[prev_len_sum : prev_len_sum + seq_len]
|
||||
prev_len_sum += seq_len
|
||||
# k/v [bsz, max_seq_len, num_kv_heads, head_dim]
|
||||
out_torch = torch_decoding(q, k_torch, v_torch, context_lengths)
|
||||
|
||||
assert out_torch.shape == out_triton.shape
|
||||
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
Reference in New Issue
Block a user