Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-02 18:21:57 +00:00)
[hotfix] polish flash attention (#1802)
This commit is contained in: parent 218c75fd9d, commit 501a9e9cd2
@@ -10,20 +10,6 @@ import subprocess
 import torch
 
-try:
-    import triton
-    import triton.language as tl
-    HAS_TRITON = True
-except ImportError:
-    print('please install triton from https://github.com/openai/triton')
-    HAS_TRITON = False
-try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-    HAS_FLASH_ATTN = True
-except ImportError:
-    HAS_FLASH_ATTN = False
-    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-
 
 
 def triton_check():
     cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
@@ -38,9 +24,26 @@ def triton_check():
     return False
 
 
-TRITON_AVALIABLE = triton_check()
+try:
+    import triton
+    import triton.language as tl
+    if triton_check():
+        HAS_TRITON = True
+    else:
+        print("triton requires cuda >= 11.4")
+        HAS_TRITON = False
+except ImportError:
+    print('please install triton from https://github.com/openai/triton')
+    HAS_TRITON = False
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    HAS_FLASH_ATTN = True
+except ImportError:
+    HAS_FLASH_ATTN = False
+    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
 
-if TRITON_AVALIABLE:
+if HAS_TRITON:
 
     @triton.jit
     def _fwd_kernel(
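For context, the hunk above moves the optional imports below triton_check() so that a single HAS_TRITON flag covers both "triton is importable" and "CUDA >= 11.4", replacing the separate TRITON_AVALIABLE flag. The following is a minimal standalone sketch of that pattern, not the repository's code: the _cuda_at_least helper is an assumed simplification (the real triton_check() probes CUDA_HOME via subprocess), and the noqa comments only mark the probe imports as intentionally unused.

# Sketch of the availability-flag pattern, with a simplified CUDA version probe.
import torch


def _cuda_at_least(major: int, minor: int) -> bool:
    version = torch.version.cuda    # e.g. "11.7"; None on CPU-only builds
    if version is None:
        return False
    got_major, got_minor = (int(part) for part in version.split(".")[:2])
    return (got_major, got_minor) >= (major, minor)


try:
    import triton    # noqa: F401
    import triton.language as tl    # noqa: F401
    HAS_TRITON = _cuda_at_least(11, 4)
    if not HAS_TRITON:
        print("triton requires cuda >= 11.4")
except ImportError:
    print("please install triton from https://github.com/openai/triton")
    HAS_TRITON = False

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func    # noqa: F401
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
    print("please install flash_attn from https://github.com/HazyResearch/flash-attention")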
@@ -394,7 +397,7 @@ if TRITON_AVALIABLE:
     Return:
         out: (batch, nheads, seq, headdim)
     """
-    if TRITON_AVALIABLE:
+    if HAS_TRITON:
         return _TritonFlashAttention.apply(q, k, v, sm_scale)
     else:
         raise RuntimeError("Triton kernel requires CUDA 11.4+!")
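With the rename above, callers can test HAS_TRITON up front instead of catching the RuntimeError raised inside the wrapper. A hedged usage sketch follows; it assumes triton_flash_attention and HAS_TRITON are importable from colossalai.kernel.cuda_native.flash_attention (as the test file below uses them) and that a CUDA device is present.

import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_TRITON

if HAS_TRITON:
    # Assumed export: the wrapper whose docstring appears in the hunk above.
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention

    batch, nheads, seq, headdim = 3, 2, 16, 8
    q = torch.randn(batch, nheads, seq, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    sm_scale = headdim ** -0.5
    out = triton_flash_attention(q, k, v, sm_scale)    # out: (batch, nheads, seq, headdim)
else:
    # Checking the flag avoids the RuntimeError("Triton kernel requires CUDA 11.4+!") path.
    print("Triton flash attention unavailable: triton missing or CUDA < 11.4")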
@@ -2,7 +2,7 @@ import pytest
 import torch
 from einops import rearrange
 
-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
 
 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import flash_attention
@@ -22,7 +22,7 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out
 
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
@@ -39,7 +39,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     ref_dq, q.grad = q.grad.clone(), None
 
     # triton implementation
-    if TRITON_AVALIABLE:
+    if HAS_TRITON:
         tri_out = triton_flash_attention(q, k, v, sm_scale)
         tri_out.backward(dout)
         tri_dv, v.grad = v.grad.clone(), None
@@ -59,7 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         raise TypeError("Error type not match!")
 
 
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
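The two skipif fixes above only correct the reason strings; the guard pattern itself is unchanged. For reference, a minimal sketch of that pytest guard pattern outside the repository (the test bodies are placeholders, not the project's tests):

import pytest

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON


@pytest.mark.skipif(not HAS_TRITON, reason="triton is not available")
def test_triton_flag_guard():
    # Placeholder: runs only when triton imports cleanly and CUDA >= 11.4.
    assert HAS_TRITON


@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash is not available")
def test_flash_attn_flag_guard():
    # Placeholder: runs only when flash_attn is importable.
    assert HAS_FLASH_ATTN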