[hotfix] polish flash attention (#1802)
@@ -10,20 +10,6 @@ import subprocess
 
 import torch
 
-try:
-    import triton
-    import triton.language as tl
-    HAS_TRITON = True
-except ImportError:
-    print('please install triton from https://github.com/openai/triton')
-    HAS_TRITON = False
-try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-    HAS_FLASH_ATTN = True
-except ImportError:
-    HAS_FLASH_ATTN = False
-    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-
 
 def triton_check():
     cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
@@ -38,9 +24,26 @@ def triton_check():
     return False
 
 
-TRITON_AVALIABLE = triton_check()
+try:
+    import triton
+    import triton.language as tl
+    if triton_check():
+        HAS_TRITON = True
+    else:
+        print("triton requires cuda >= 11.4")
+        HAS_TRITON = False
+except ImportError:
+    print('please install triton from https://github.com/openai/triton')
+    HAS_TRITON = False
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    HAS_FLASH_ATTN = True
+except ImportError:
+    HAS_FLASH_ATTN = False
+    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
 
-if TRITON_AVALIABLE:
+if HAS_TRITON:
 
     @triton.jit
     def _fwd_kernel(
@@ -394,7 +397,7 @@ if TRITON_AVALIABLE:
     Return:
         out: (batch, nheads, seq, headdim)
     """
-    if TRITON_AVALIABLE:
+    if HAS_TRITON:
         return _TritonFlashAttention.apply(q, k, v, sm_scale)
     else:
         raise RuntimeError("Triton kernel requires CUDA 11.4+!")
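In short, the hotfix folds the old module-level TRITON_AVALIABLE = triton_check() flag into HAS_TRITON, which is now set only when both the triton import and the CUDA >= 11.4 check succeed, and the kernel entry point branches on that single flag. Below is a minimal, self-contained sketch of the same probe-once pattern; naive_attention and the demo at the end are illustrative additions, not ColossalAI code.

# Minimal sketch of the probe-once availability flags used in the diff above.
# naive_attention and the __main__ demo are illustrative, not ColossalAI code.
import torch

try:
    import triton  # noqa: F401  (the patched module additionally requires CUDA >= 11.4)
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func  # noqa: F401
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False


def naive_attention(q, k, v, sm_scale):
    """Plain PyTorch reference; q, k, v and the output are (batch, nheads, seq, headdim)."""
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * sm_scale
    return torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), v)


if __name__ == "__main__":
    q = k = v = torch.randn(2, 4, 16, 8)
    out = naive_attention(q, k, v, sm_scale=8 ** -0.5)
    # callers branch on the flags instead of re-running the environment check
    print(out.shape, "HAS_TRITON:", HAS_TRITON, "HAS_FLASH_ATTN:", HAS_FLASH_ATTN)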