diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
index d037b89f8..33380b8fc 100644
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ b/colossalai/kernel/cuda_native/flash_attention.py
@@ -11,7 +11,7 @@ import subprocess
 import torch


-def triton_check():
+def triton_cuda_check():
     cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
     cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
     cuda_version = cuda_version.split('release ')[1]
@@ -27,7 +27,7 @@ def triton_check():
 try:
     import triton
     import triton.language as tl
-    if triton_check():
+    if triton_cuda_check():
         HAS_TRITON = True
     else:
         print("triton requires cuda >= 11.4")
@@ -36,7 +36,11 @@ except ImportError:
     print('please install triton from https://github.com/openai/triton')
     HAS_TRITON = False
 try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_func,
+        flash_attn_unpadded_kvpacked_func,
+        flash_attn_unpadded_qkvpacked_func,
+    )
     HAS_FLASH_ATTN = True
 except ImportError:
     HAS_FLASH_ATTN = False
@@ -405,12 +409,63 @@ if HAS_TRITON:

 if HAS_FLASH_ATTN:

-    def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
+    def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
         """
         Arguments:
-            q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
-            k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
-            v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
+            qkv: (batch * seqlen, 3, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
+            sm_scale: float. The scaling of QK^T before applying softmax.
+                Default to 1 / sqrt(headdim).
+            dropout_p: float.
+            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        Return:
+            out: (total, nheads, headdim).
+        """
+        max_s = seq_len
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
+                                  device=qkv.device)
+        out = flash_attn_unpadded_qkvpacked_func(
+            qkv, cu_seqlens, max_s, dropout_p,
+            softmax_scale=sm_scale, causal=causal
+        )
+        return out
+
+
+    def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
+        """
+        Arguments:
+            q: (batch * q_seqlen, nheads, headdim)
+            kv: (batch * kv_seqlen, 2, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
+            sm_scale: float. The scaling of QK^T before applying softmax.
+                Default to 1 / sqrt(headdim).
+            dropout_p: float.
+            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        Return:
+            out: (total, nheads, headdim).
+        """
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
+        cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
+        out = flash_attn_unpadded_kvpacked_func(q,
+                                                kv,
+                                                cu_seqlens_q,
+                                                cu_seqlens_k,
+                                                q_seqlen,
+                                                kv_seqlen,
+                                                dropout_p,
+                                                sm_scale,
+                                                causal)
+        return out
+
+
+    def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
+        """
+        Arguments:
+            q: (batch * q_seqlen, nheads, headdim)
+            k: (batch * kv_seqlen, nheads, headdim)
+            v: (batch * kv_seqlen, nheads, headdim)
             batch_size: int.
             seq_len: int.
             dropout_p: float. Dropout probability.
@@ -420,16 +475,15 @@ if HAS_FLASH_ATTN:
         Return:
             out: (total, nheads, headdim).
         """
-        lengths = torch.full((batch_size,), fill_value=seq_len, device=q.device)
-        cu_seqlens = torch.zeros((batch_size + 1,), device=q.device, dtype=torch.int32)
-        cu_seqlens[1:] = lengths.cumsum(0)
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
+        cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=k.device)
         return flash_attn_unpadded_func(q,
                                         k,
                                         v,
-                                        cu_seqlens_q=cu_seqlens,
-                                        cu_seqlens_k=cu_seqlens,
-                                        max_seqlen_q=seq_len,
-                                        max_seqlen_k=seq_len,
-                                        dropout_p=dropout_p,
-                                        softmax_scale=sm_scale,
-                                        causal=causal)
+                                        cu_seqlens_q,
+                                        cu_seqlens_kv,
+                                        q_seqlen,
+                                        kv_seqlen,
+                                        dropout_p,
+                                        sm_scale,
+                                        causal)
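Note (not part of the patch): every new wrapper builds its cumulative-sequence-length tensor with a single torch.arange call instead of the previous torch.full(...).cumsum(0) pair. The two constructions agree whenever all sequences in the batch share one length, which is the only case these fixed-length wrappers handle. A minimal sketch of that equivalence, with arbitrary batch_size/seq_len values:

# Sanity sketch only; batch_size and seq_len are arbitrary example values.
import torch

batch_size, seq_len = 3, 4

# old construction: a zero tensor filled with the cumulative sum of per-sample lengths
lengths = torch.full((batch_size,), fill_value=seq_len)
cu_seqlens_old = torch.zeros((batch_size + 1,), dtype=torch.int32)
cu_seqlens_old[1:] = lengths.cumsum(0)

# new construction: one arange, valid because every sequence has the same length
cu_seqlens_new = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32)

assert torch.equal(cu_seqlens_old, cu_seqlens_new)    # both are [0, 4, 8, 12]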
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 195de0d28..d2409fc62 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -5,7 +5,8 @@ from einops import rearrange
 from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON

 if HAS_FLASH_ATTN:
-    from colossalai.kernel.cuda_native.flash_attention import flash_attention
+    from colossalai.kernel.cuda_native.flash_attention import (
+        flash_attention_q_k_v, flash_attention_q_kv, flash_attention_qkv)

 if HAS_TRITON:
     from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
@@ -22,8 +23,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out


-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
+@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
     q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
@@ -39,28 +40,20 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     ref_dq, q.grad = q.grad.clone(), None

     # triton implementation
-    if HAS_TRITON:
-        tri_out = triton_flash_attention(q, k, v, sm_scale)
-        tri_out.backward(dout)
-        tri_dv, v.grad = v.grad.clone(), None
-        tri_dk, k.grad = k.grad.clone(), None
-        tri_dq, q.grad = q.grad.clone(), None
-        # compare
-        assert torch.allclose(ref_out, tri_out, atol=1e-3)
-        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
-        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
-        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
-    else:
-        try:
-            tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
-        except RuntimeError:
-            pass
-        else:
-            raise TypeError("Error type not match!")
+    tri_out = triton_flash_attention(q, k, v, sm_scale)
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    # compare
+    assert torch.allclose(ref_out, tri_out, atol=1e-3)
+    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
+    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
+    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)


 @pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
     q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
@@ -78,15 +71,40 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):

     # flash implementation
     q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
-    tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
     dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
-    tri_out.backward(dout, retain_graph=True)
-    tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
-    tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
-                                          (tri_out, tri_dq, tri_dk, tri_dv))
+    for i in range(3):
+        if i == 0:
+            tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)
+        elif i == 1:
+            kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
+            tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)
+        else:
+            qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
+            tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)

-    # compare
-    assert torch.allclose(ref_out, tri_out, atol=1e-3)
-    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
-    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
-    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
+        tri_out.backward(dout, retain_graph=True)
+
+        if i == 0:
+            tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
+            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                                  (tri_out, tri_dq, tri_dk, tri_dv))
+        elif i == 1:
+            tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
+            tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
+            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
+        else:
+            tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
+            tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
+            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
+
+        # compare
+        assert torch.allclose(ref_out, tri_out, atol=1e-3)
+        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
+        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
+        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
+
+
+if __name__ == '__main__':
+    test_flash_attention(3, 4, 2, 16)
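Note (not part of the patch): the call pattern the updated test exercises, shown as a standalone sketch. It assumes a CUDA device with the flash-attn package installed; the shapes (batch 3, 4 heads, sequence length 2, head dim 16) simply mirror the test parametrization, and sm_scale follows the documented 1 / sqrt(headdim) default.

# Illustrative usage of the three new wrappers; not executed by the test suite.
import torch
from einops import rearrange

from colossalai.kernel.cuda_native.flash_attention import (
    flash_attention_q_k_v, flash_attention_q_kv, flash_attention_qkv)

Z, H, N_CTX, D_HEAD = 3, 4, 2, 16
sm_scale = 1.0 / (D_HEAD ** 0.5)
q, k, v = [torch.randn((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda") for _ in range(3)]

# flash-attn works on (batch * seqlen, nheads, headdim) tensors
q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), (q, k, v))

# separate q, k, v
out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)

# packed kv: (batch * seqlen, 2, nheads, headdim)
kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)

# packed qkv: (batch * seqlen, 3, nheads, headdim)
qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)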