[shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
This commit is contained in:
Hongxin Liu
2024-03-27 11:19:32 +08:00
committed by GitHub
parent 9a3321e9f4
commit 19e1a5cf16
45 changed files with 2543 additions and 1170 deletions

View File

@@ -101,13 +101,13 @@ class MyExtension(_Extension):
self._support_jit = True
self.priority = 10
def is_hardware_available(self) -> bool:
def is_available(self) -> bool:
"""
Return if the required hardware can be found.
"""
...
def assert_hardware_compatible(self) -> None:
def assert_compatible(self) -> None:
"""
Check if the hardware required by the kernel is compatible.
"""

View File

@@ -1,9 +1,5 @@
from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension
from .flash_attention import (
FlashAttentionDaoCudaExtension,
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
)
from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension
from .layernorm import LayerNormCudaExtension
from .moe import MoeCudaExtension
from .optimizer import FusedOptimizerCudaExtension
@@ -18,7 +14,7 @@ ALL_EXTENSIONS = [
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
FlashAttentionDaoCudaExtension,
FlashAttentionXformersCudaExtension,
FlashAttentionSdpaCudaExtension,
FlashAttentionNpuExtension,
]
@@ -31,6 +27,6 @@ __all__ = [
"ScaledMaskedSoftmaxCudaExtension",
"ScaledUpperTriangleMaskedSoftmaxCudaExtension",
"FlashAttentionDaoCudaExtension",
"FlashAttentionXformersCudaExtension",
"FlashAttentionSdpaCudaExtension",
"FlashAttentionNpuExtension",
]

View File

@@ -58,13 +58,13 @@ class _Extension(ABC):
return cache_directory
@abstractmethod
def is_hardware_available(self) -> bool:
def is_available(self) -> bool:
"""
Check if the hardware required by the kernel is available.
"""
@abstractmethod
def assert_hardware_compatible(self) -> None:
def assert_compatible(self) -> None:
"""
Check if the hardware required by the kernel is compatible.
"""

View File

@@ -7,11 +7,11 @@ class CpuAdamArmExtension(_CppExtension):
def __init__(self):
super().__init__(name="cpu_adam_arm")
def is_hardware_available(self) -> bool:
def is_available(self) -> bool:
# only arm allowed
return platform.machine() == "aarch64"
def assert_hardware_compatible(self) -> None:
def assert_compatible(self) -> None:
arch = platform.machine()
assert (
arch == "aarch64"

View File

@@ -8,15 +8,15 @@ class CpuAdamX86Extension(_CudaExtension):
def __init__(self):
super().__init__(name="cpu_adam_x86")
def is_hardware_available(self) -> bool:
return platform.machine() == "x86_64" and super().is_hardware_available()
def is_available(self) -> bool:
return platform.machine() == "x86_64" and super().is_available()
def assert_hardware_compatible(self) -> None:
def assert_compatible(self) -> None:
arch = platform.machine()
assert (
arch == "x86_64"
), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}"
super().assert_hardware_compatible()
super().assert_compatible()
# necessary 4 functions
def sources_files(self):

View File

@@ -22,7 +22,7 @@ class _CudaExtension(_CppExtension):
This function should return a list of nvcc compilation flags for extensions.
"""
def is_hardware_available(self) -> bool:
def is_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
@@ -32,7 +32,7 @@ class _CudaExtension(_CppExtension):
cuda_available = False
return cuda_available
def assert_hardware_compatible(self) -> None:
def assert_compatible(self) -> None:
from torch.utils.cpp_extension import CUDA_HOME
if not CUDA_HOME:

View File

@@ -1,20 +1,14 @@
from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension
from .flash_attention_npu import FlashAttentionNpuExtension
from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension
from .flash_attention_sdpa_cuda import FlashAttentionSdpaCudaExtension
try:
# TODO: remove this after updating openmoe example
import flash_attention # noqa
HAS_FLASH_ATTN = True
except:
HAS_FLASH_ATTN = False
try:
import xformers # noqa
HAS_MEM_EFF_ATTN = True
except:
HAS_MEM_EFF_ATTN = False
__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"]
__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension"]

View File

@@ -5,17 +5,20 @@ class FlashAttentionDaoCudaExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10)
def is_hardware_available(self) -> bool:
def is_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func # noqa
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_hardware_compatible(self) -> bool:
def assert_compatible(self) -> bool:
pass
def build_aot(self) -> None:
@@ -29,65 +32,65 @@ class FlashAttentionDaoCudaExtension(_Extension):
)
def load(self):
try:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
except ImportError:
raise ModuleNotFoundError(
(
"We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
)
)
from typing import Optional
import torch
from einops import rearrange
from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func
from flash_attn.bert_padding import index_first_axis, pad_input
def _unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: "SeqLenInfo",
seq_len_info_kv: "SeqLenInfo",
origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
scale: Optional[float] = None,
attention_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
):
"""
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
# check if the input is in allowed dtypes
if padded:
if seq_len_info_kv == None:
seq_len_info_kv = seq_len_info_q
attn_out = flash_attn_varlen_func(
# [B, N, S, D] -> [B, S, N, D]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
b, s_q = q.shape[:2]
if cu_seqlens_q is not None:
# padded / padded causal
# unpad input: [B, S, N, D] -> [T, N, D]
q = _unpad_input(q, q_indices)
kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
attn_output = flash_attn_varlen_kvpacked_func(
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
)
# pad output: [T, N, D] -> [B, S, N, D]
attn_output = pad_input(attn_output, q_indices, b, s_q)
else:
# causal / no attn mask
attn_output = flash_attn_func(
q,
k,
v,
seq_len_info_q.cu_seqlens,
seq_len_info_kv.cu_seqlens,
seq_len_info_q.max_seqlen,
seq_len_info_kv.max_seqlen,
dropout_p,
scale,
causal,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
)
else:
attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
return attn_out
# [B, S, N, D] -> [B, N, S, D]
return attn_output.transpose(1, 2)
return flash_attention

View File

@@ -5,15 +5,15 @@ class FlashAttentionNpuExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False)
def is_hardware_available(self) -> bool:
def is_available(self) -> bool:
try:
import torch_npu # noqa
import torch_npu
return True
return hasattr(torch_npu, "npu_fusion_attention")
except:
return False
def assert_hardware_compatible(self) -> bool:
def assert_compatible(self) -> bool:
pass
def build_aot(self) -> None:
@@ -27,47 +27,36 @@ class FlashAttentionNpuExtension(_Extension):
)
def load(self):
import torch
from einops import rearrange
from typing import Optional
def npu_sdpa_attention(
import torch
import torch_npu
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q=None,
seq_len_info_kv=None,
origin_attn_mask: torch.Tensor = None,
dropout_p: float = 0.0,
scale: float = 1.0,
causal=None,
padded=None,
scale: Optional[float] = None,
attention_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
):
"""
The scaled dot product attention.
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
scale: float. The scaling of QK^T before applying softmax.
Default to 1.
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
output = torch.nn.functional.scaled_dot_product_attention(
num_heads = q.size(1)
return torch_npu.npu_fusion_attention(
q,
k,
v,
attn_mask=origin_attn_mask,
dropout_p=dropout_p,
is_causal=origin_attn_mask is None,
num_heads,
"BNSD",
atten_mask=attention_mask.bool(),
scale=scale,
)
output = rearrange(output, "b h s d -> b s (h d)")
return output
keep_prob=1 - dropout_p,
)[0]
return npu_sdpa_attention
return flash_attention

View File

@@ -0,0 +1,56 @@
from ..base_extension import _Extension
class FlashAttentionSdpaCudaExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_sdpa_cuda", support_aot=False, support_jit=False)
def is_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_compatible(self) -> bool:
pass
def build_aot(self) -> None:
raise NotImplementedError("Flash attention SDPA does not require ahead-of-time compilation.")
def build_jit(self) -> None:
raise NotImplementedError("Flash attention SDPA does not require just-in-time compilation.")
def load(self):
from typing import Optional
import torch
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
attention_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
):
return torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attention_mask,
dropout_p=dropout_p,
scale=scale,
)
return flash_attention

View File

@@ -1,94 +0,0 @@
from ..base_extension import _Extension
class FlashAttentionXformersCudaExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)
def is_hardware_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_hardware_compatible(self) -> bool:
pass
def build_aot(self) -> None:
raise NotImplementedError(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
def build_jit(self) -> None:
raise NotImplementedError(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
def load(self):
try:
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
except ImportError:
raise ModuleNotFoundError(
(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
)
from typing import Optional
import torch
allow_alibi = True
for op in MemoryEfficientAttentionCutlassOp:
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
def mem_eff_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: "SeqLenInfo",
seq_len_info_kv: "SeqLenInfo",
origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
attn_bias = None
if padded: # bert style
if not causal:
attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
elif causal: # gpt style
attn_bias = LowerTriangularMask()
if bias is not None: # alibi / relative position embedding
assert allow_alibi, "flash attention with bias is not supported in this system."
assert causal, "attention with bias is only supported for causal attention so far."
attn_bias = attn_bias.add_bias(bias)
if padded:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
# shape: (b*s, n, d)
if padded:
out = out.squeeze(0)
return out
return mem_eff_attention