mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 11:08:50 +00:00
[feat] refactored extension module (#5298)
* [feat] refactored extension module * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish
This commit is contained in:
20
extensions/flash_attention/__init__.py
Normal file
20
extensions/flash_attention/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension
|
||||
from .flash_attention_npu import FlashAttentionNpuExtension
|
||||
from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension
|
||||
|
||||
# Probe for the optional attention backends. These flags only record whether
# the third-party packages can be imported; nothing from them is used here.
try:
    # The Dao-AILab package is installed as ``flash-attn`` and imported as
    # ``flash_attn`` (the module name the Dao CUDA extension's ``load`` imports
    # from), so that is the name that has to be probed.
    import flash_attn  # noqa

    HAS_FLASH_ATTN = True
except Exception:
    # Best-effort probe: any import-time failure means "not available".
    # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt/SystemExit through.
    HAS_FLASH_ATTN = False

try:
    import xformers  # noqa

    HAS_MEM_EFF_ATTN = True
except Exception:
    # Same best-effort policy as above.
    HAS_MEM_EFF_ATTN = False
|
||||
|
||||
|
||||
__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"]
|
93
extensions/flash_attention/flash_attention_dao_cuda.py
Normal file
93
extensions/flash_attention/flash_attention_dao_cuda.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from ..base_extension import _Extension
|
||||
|
||||
|
||||
class FlashAttentionDaoCudaExtension(_Extension):
    """Flash attention backed by the third-party ``flash-attn`` package.

    Nothing is compiled by this extension: ``build_aot`` and ``build_jit``
    always raise, and ``load`` returns a thin wrapper around the functions
    imported from ``flash_attn.flash_attn_interface``.
    """

    def __init__(self):
        super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10)

    def is_hardware_available(self) -> bool:
        """Return what ``torch.cuda.is_available()`` reports, or ``False`` if probing fails."""
        # Best-effort probe: a missing or broken torch install is reported as
        # "no CUDA" instead of being raised to the caller.
        try:
            import torch

            return torch.cuda.is_available()
        except Exception:
            return False

    def assert_hardware_compatible(self) -> None:
        """No compatibility checks are performed for this extension."""

    def build_aot(self) -> None:
        """Always raise: there is nothing to build ahead of time.

        Raises:
            NotImplementedError: unconditionally, with installation instructions.
        """
        raise NotImplementedError(
            "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'."
        )

    def build_jit(self) -> None:
        """Always raise: there is nothing to build just in time.

        Raises:
            NotImplementedError: unconditionally, with installation instructions.
        """
        raise NotImplementedError(
            "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
        )

    def load(self):
        """Import ``flash-attn`` and return the ``flash_attention`` wrapper function.

        Raises:
            ModuleNotFoundError: if ``flash_attn.flash_attn_interface`` cannot be imported.
        """
        try:
            from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ModuleNotFoundError(
                (
                    "We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
                )
            ) from err

        from typing import Optional

        import torch

        def flash_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_len_info_q: "SeqLenInfo",
            seq_len_info_kv: "SeqLenInfo",
            origin_attn_mask: Optional[torch.Tensor] = None,
            bias: Optional[torch.Tensor] = None,
            dropout_p: float = 0.0,
            scale: Optional[float] = None,
            causal: bool = False,
            padded: bool = False,
        ):
            """
            Dispatch to ``flash_attn_varlen_func`` (padded) or ``flash_attn_func`` (otherwise).

            Arguments:
                q: (batch, q_seqlen, nheads, headdim)
                k: (batch, kv_seqlen, nheads, headdim)
                v: (batch, kv_seqlen, nheads, headdim)
                    NOTE(review): on the ``padded`` path the tensors are handed to
                    ``flash_attn_varlen_func`` as-is; presumably they are already
                    packed along the token dimension -- confirm against the caller.
                seq_len_info_q: object providing ``cu_seqlens`` and ``max_seqlen`` for the
                    queries. Only read when ``padded`` is True.
                seq_len_info_kv: same for keys/values. When ``None`` on the padded path,
                    ``seq_len_info_q`` is used instead.
                origin_attn_mask: accepted but not read by this implementation.
                bias: accepted but not read by this implementation.
                dropout_p: float. Dropout probability.
                scale: float. The scaling of QK^T before applying softmax,
                    forwarded as the softmax scale of the flash-attn call.
                causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
                padded: bool. Selects the variable-length code path.
            Return:
                attn_out: whatever the selected flash-attn function returns.
            """
            if not padded:
                return flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)

            # Identity check: ``== None`` would invoke any custom ``__eq__`` on the info object.
            if seq_len_info_kv is None:
                seq_len_info_kv = seq_len_info_q

            return flash_attn_varlen_func(
                q,
                k,
                v,
                seq_len_info_q.cu_seqlens,
                seq_len_info_kv.cu_seqlens,
                seq_len_info_q.max_seqlen,
                seq_len_info_kv.max_seqlen,
                dropout_p,
                scale,
                causal,
            )

        return flash_attention
|
73
extensions/flash_attention/flash_attention_npu.py
Normal file
73
extensions/flash_attention/flash_attention_npu.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from ..base_extension import _Extension
|
||||
|
||||
|
||||
class FlashAttentionNpuExtension(_Extension):
    """Attention for NPU built on ``torch.nn.functional.scaled_dot_product_attention``.

    Nothing is compiled by this extension: ``build_aot`` and ``build_jit``
    always raise, and ``load`` returns a wrapper function.
    """

    def __init__(self):
        super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False)

    def is_hardware_available(self) -> bool:
        """Return ``True`` when ``torch_npu`` can be imported, ``False`` otherwise."""
        # Best-effort probe: any import-time failure counts as "not available".
        # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt/SystemExit through.
        try:
            import torch_npu  # noqa
        except Exception:
            return False
        return True

    def assert_hardware_compatible(self) -> None:
        """No compatibility checks are performed for this extension."""

    def build_aot(self) -> None:
        """Always raise: there is nothing to build ahead of time.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError(
            "Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu."
        )

    def build_jit(self) -> None:
        """Always raise: there is nothing to build just in time.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError(
            "Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu."
        )

    def load(self):
        """Return the ``npu_sdpa_attention`` wrapper function."""
        import torch
        from einops import rearrange

        def npu_sdpa_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_len_info_q=None,
            seq_len_info_kv=None,
            origin_attn_mask: torch.Tensor = None,
            dropout_p: float = 0.0,
            scale: float = 1.0,
            causal=None,
            padded=None,
        ):
            """
            The scaled dot product attention.

            Arguments:
                q: (batch, q_seqlen, nheads, headdim)
                k: (batch, kv_seqlen, nheads, headdim)
                v: (batch, kv_seqlen, nheads, headdim)
                seq_len_info_q: accepted but not read by this implementation.
                seq_len_info_kv: accepted but not read by this implementation.
                origin_attn_mask: forwarded as ``attn_mask``. When it is ``None`` the
                    call is made with ``is_causal=True``.
                dropout_p: float. Dropout probability.
                scale: float. The scaling of QK^T before applying softmax.
                    Default to 1.
                causal: accepted but not read by this implementation.
                padded: accepted but not read by this implementation.
            Return:
                attn_out: (batch, q_seqlen, nheads * headdim) -- the head and
                    head-dim axes are merged before returning.
            """
            # scaled_dot_product_attention is called with heads ahead of the sequence axis.
            q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]

            # NOTE(review): causality is derived only from the absence of a mask;
            # the ``causal`` argument is never consulted -- confirm this is intended.
            output = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=origin_attn_mask,
                dropout_p=dropout_p,
                is_causal=origin_attn_mask is None,
                scale=scale,
            )

            # Back to sequence-major layout, flattening (heads, headdim) into one axis.
            output = rearrange(output, "b h s d -> b s (h d)")
            return output

        return npu_sdpa_attention
|
94
extensions/flash_attention/flash_attention_xformers_cuda.py
Normal file
94
extensions/flash_attention/flash_attention_xformers_cuda.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from ..base_extension import _Extension
|
||||
|
||||
|
||||
class FlashAttentionXformersCudaExtension(_Extension):
    """Memory-efficient attention backed by the third-party ``xformers`` package.

    Nothing is compiled by this extension: ``build_aot`` and ``build_jit``
    always raise, and ``load`` returns a wrapper around
    ``xformers.ops.fmha.memory_efficient_attention``.
    """

    def __init__(self):
        super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)

    def is_hardware_available(self) -> bool:
        """Return what ``torch.cuda.is_available()`` reports, or ``False`` if probing fails."""
        # Best-effort probe: a missing or broken torch install is reported as
        # "no CUDA" instead of being raised to the caller.
        try:
            import torch

            return torch.cuda.is_available()
        except Exception:
            return False

    def assert_hardware_compatible(self) -> None:
        """No compatibility checks are performed for this extension."""

    def build_aot(self) -> None:
        """Always raise: there is nothing to build ahead of time.

        Raises:
            NotImplementedError: unconditionally, with installation instructions.
        """
        raise NotImplementedError(
            "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
        )

    def build_jit(self) -> None:
        """Always raise: there is nothing to build just in time.

        Raises:
            NotImplementedError: unconditionally, with installation instructions.
        """
        raise NotImplementedError(
            "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
        )

    def load(self):
        """Import ``xformers`` and return the ``mem_eff_attention`` wrapper function.

        Raises:
            ModuleNotFoundError: if the required ``xformers`` symbols cannot be imported.
        """
        try:
            from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
            from xformers.ops.fmha.attn_bias import (
                BlockDiagonalCausalMask,
                BlockDiagonalMask,
                LowerTriangularMask,
                LowerTriangularMaskWithTensorBias,
            )
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ModuleNotFoundError(
                (
                    "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
                )
            ) from err
        from typing import Optional

        import torch

        # A tensor bias is only usable when every Cutlass op lists
        # LowerTriangularMaskWithTensorBias among its supported bias types.
        allow_alibi = all(
            LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES
            for op in MemoryEfficientAttentionCutlassOp
        )

        def mem_eff_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_len_info_q: "SeqLenInfo",
            seq_len_info_kv: "SeqLenInfo",
            origin_attn_mask: Optional[torch.Tensor] = None,
            bias: Optional[torch.Tensor] = None,
            dropout_p: float = 0.0,
            scale: Optional[float] = None,
            causal: bool = False,
            padded: bool = False,
        ):
            """
            Run ``memory_efficient_attention`` with a mask chosen from ``padded``/``causal``.

            Arguments:
                q, k, v: attention inputs, forwarded to xformers. On the padded path a
                    leading axis of size 1 is added before the call and removed from the output.
                seq_len_info_q: object providing ``seqlens`` for the queries. Only read
                    when ``padded`` is True.
                seq_len_info_kv: object providing ``seqlens`` for keys/values. Only read
                    when ``padded`` is True.
                origin_attn_mask: accepted but not read by this implementation.
                bias: optional tensor bias (alibi / relative position embedding) added to
                    the mask; requires ``causal`` and Cutlass support for tensor biases.
                dropout_p: float. Dropout probability, forwarded as ``p``.
                scale: float. Forwarded as ``scale``.
                causal: bool. Selects a causal mask.
                padded: bool. Selects block-diagonal masks built from the sequence lengths.
            Return:
                The output of ``memory_efficient_attention`` (with the added leading
                axis removed on the padded path).
            """
            attn_bias = None
            if padded:  # bert style
                mask_cls = BlockDiagonalCausalMask if causal else BlockDiagonalMask
                attn_bias = mask_cls.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
            elif causal:  # gpt style
                attn_bias = LowerTriangularMask()

            if bias is not None:  # alibi / relative position embedding
                assert allow_alibi, "flash attention with bias is not supported in this system."
                assert causal, "attention with bias is only supported for causal attention so far."
                # NOTE(review): with ``padded`` the mask built above is block-diagonal;
                # confirm that mask type provides ``add_bias`` before relying on this path.
                attn_bias = attn_bias.add_bias(bias)

            if padded:
                # Add a leading axis of size 1 for the xformers call.
                q = q.unsqueeze(0)
                k = k.unsqueeze(0)
                v = v.unsqueeze(0)

            out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)

            # shape: (b*s, n, d)
            if padded:
                out = out.squeeze(0)

            return out

        return mem_eff_attention
|
Reference in New Issue
Block a user