Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
[fix] coloattention support flash attention 2 (#4347)
Improved the ColoAttention interface to support FlashAttention 2. Solves #4322.
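For context, a minimal usage sketch of the reworked interface, pieced together from the new imports and the updated test_attention_gpt in the diff below. The ColoAttention constructor arguments and the tensor shapes are illustrative assumptions and are not shown in this diff; only the import paths, the mask construction, and the call with attn_mask / AttnMaskType.paddedcausal are taken from the test.

import torch
from einops import rearrange

from colossalai.kernel.cuda_native.mha.mha import ColoAttention
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

# Illustrative shapes (batch, sequence, heads, head dim); not taken from the diff.
B, S, H, D_HEAD = 4, 8, 4, 16
D = H * D_HEAD
dtype = torch.float16

c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
attn = ColoAttention(D, H)  # assumed constructor signature; not visible in this diff

x = torch.randn((B, S, D), dtype=dtype, device="cuda")
qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)

# Padding mask built as in the updated test: sample i keeps S - i valid tokens.
mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
assert list(y.shape) == [B, S, D]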
@@ -4,11 +4,15 @@ import pytest
 import torch
 from einops import rearrange
 
-from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN
+from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
+from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
 from colossalai.testing import clear_cache_before_run, parameterize
 
-if HAS_MEM_EFF_ATTN:
-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
+    from colossalai.kernel.cuda_native.mha.mha import ColoAttention
+    from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+
+DTYPE = [torch.float16, torch.bfloat16, torch.float32]
 
 
 def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@@ -22,10 +26,13 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out
 
 
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(1, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_gpt(proj_shape, dtype):
+    # TODO check output value
+    (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -35,7 +42,11 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
 
     qkv = c_attn(x)
     q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
-    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
+
+    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
+    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
+
+    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
 
     assert list(y.shape) == [B, S, D]
 
@@ -43,10 +54,12 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
     y.backward(dy)
 
 
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_bert(proj_shape, dtype):
+    (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -67,10 +80,12 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
     y.backward(dy)
 
 
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_no_mask(proj_shape, dtype):
+    (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -87,10 +102,12 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
     y.backward(dy)
 
 
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
-def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_cross_attention(proj_shape, dtype):
+    (B, S, T, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")