updated flash attention api

zbian authored 2022-11-14 17:11:33 +08:00, committed by アマデウス
parent 36c0f3ea5b
commit 6877121377
3 changed files with 64 additions and 36 deletions


@@ -6,7 +6,11 @@ from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import (
-        flash_attention_q_k_v, flash_attention_q_kv, flash_attention_qkv)
+        MaskedFlashAttention,
+        flash_attention_q_k_v,
+        flash_attention_q_kv,
+        flash_attention_qkv,
+    )
 if HAS_TRITON:
     from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
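
A minimal sketch of the guarded-import pattern this hunk establishes, assuming only the names visible in the diff (HAS_FLASH_ATTN, HAS_TRITON, the three flash_attention_* entry points, MaskedFlashAttention, and triton_flash_attention); the skip reason and assertions below are illustrative, not part of the commit.

# Sketch: guard the optional kernels behind the availability flags, as the
# updated test file does. Only the imported names come from the diff; the
# test body here is illustrative.
import pytest

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import (
        MaskedFlashAttention,
        flash_attention_q_k_v,
        flash_attention_q_kv,
        flash_attention_qkv,
    )
if HAS_TRITON:
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention


@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash attention kernels are not installed")
def test_new_import_surface():
    # The updated API groups one entry point per qkv packing layout
    # (separate q/k/v, packed kv, packed qkv) plus a module-style wrapper
    # that additionally takes an attention mask.
    for fn in (flash_attention_q_k_v, flash_attention_q_kv, flash_attention_qkv):
        assert callable(fn)
    assert callable(MaskedFlashAttention)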
@@ -87,17 +91,17 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         if i == 0:
             tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
             tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
-                                                  (tri_out, tri_dq, tri_dk, tri_dv))
+                                                  (tri_out, tri_dq, tri_dk, tri_dv))
         elif i == 1:
             tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
             tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
             tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
-                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
+                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
         else:
             tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
             tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
             tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
-                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
+                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
 
         # compare
         assert torch.allclose(ref_out, tri_out, atol=1e-3)
@@ -106,5 +110,19 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
 
 
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
+def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
+    attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)
+    qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    attention_mask = torch.randint(2, (Z, H)).cuda().bool()
+    out = attn(qkv, attention_mask)
+    dout = torch.rand_like(out)
+    out.backward(dout)
+
+
 if __name__ == '__main__':
     test_flash_attention(3, 4, 2, 16)
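
The new test above is also the clearest usage reference for MaskedFlashAttention. Below is a minimal standalone sketch that mirrors it; the constructor arguments (N_CTX, D_HEAD, dropout 0.1), the packed (Z, H, 3 * N_CTX * D_HEAD) qkv layout, and the boolean (Z, H) mask are copied from the test, not from separate API documentation, so treat them as assumptions about the intended call pattern.

# Minimal sketch mirroring test_masked_flash_attention; shapes and constructor
# arguments are taken from the test added in this commit and are assumptions
# about the intended API, not documented guarantees.
import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import MaskedFlashAttention

    Z, H, N_CTX, D_HEAD = 3, 4, 2, 16                  # same sizes the test parametrizes
    attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)    # positional arguments as in the test

    # Packed qkv activations in fp16 on GPU, plus a boolean mask shaped (Z, H) as in the test.
    qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD),
                      dtype=torch.float16, device="cuda", requires_grad=True)
    attention_mask = torch.randint(2, (Z, H), device="cuda").bool()

    out = attn(qkv, attention_mask)                    # forward pass
    out.backward(torch.rand_like(out))                 # gradients flow back to qkv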