mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
updated attention kernel (#2133)
This commit is contained in:
@@ -2,7 +2,7 @@ import pytest
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
|
||||
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON
|
||||
|
||||
if HAS_FLASH_ATTN:
|
||||
from colossalai.kernel.cuda_native.flash_attention import (
|
||||
@@ -15,6 +15,9 @@ if HAS_FLASH_ATTN:
|
||||
if HAS_TRITON:
|
||||
from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
|
||||
|
||||
if HAS_MEM_EFF_ATTN:
|
||||
from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
|
||||
|
||||
|
||||
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
@@ -124,5 +127,20 @@ def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
out.backward(dout)
|
||||
|
||||
|
||||
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
|
||||
def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)
|
||||
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
|
||||
out = attn(q, k, v, attention_mask=LowerTriangularMask())
|
||||
|
||||
dout = torch.rand_like(out)
|
||||
out.backward(dout)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_flash_attention(3, 4, 2, 16)
|
||||
|
Reference in New Issue
Block a user