ColossalAI/extensions/flash_attention/flash_attention_npu.py
Hongxin Liu 19e1a5cf16
[shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
2024-03-27 11:19:32 +08:00

63 lines
1.9 KiB
Python

from ..base_extension import _Extension
class FlashAttentionNpuExtension(_Extension):
    """Extension exposing the fused attention kernel shipped with ``torch_npu``.

    Nothing is compiled by this extension (both AOT and JIT builds are
    disabled); ``load`` simply returns a Python wrapper around
    ``torch_npu.npu_fusion_attention``.
    """

    def __init__(self):
        super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False)

    def is_available(self) -> bool:
        """Return True if ``torch_npu`` imports and provides ``npu_fusion_attention``."""
        try:
            import torch_npu
        except Exception:
            # Deliberately broader than ImportError so that any failure while
            # importing torch_npu still reports "not available", but no longer a
            # bare ``except`` that would also swallow KeyboardInterrupt/SystemExit.
            return False
        return hasattr(torch_npu, "npu_fusion_attention")

    def assert_compatible(self) -> bool:
        """Perform no compatibility checks.

        NOTE(review): annotated ``-> bool`` but implicitly returns ``None``;
        confirm which contract the base class expects.
        """
        pass

    def build_aot(self) -> None:
        """Always raise: there is nothing to compile ahead of time.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError(
            "Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu."
        )

    def build_jit(self) -> None:
        """Always raise: there is nothing to compile just in time.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError(
            "Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu."
        )

    def load(self):
        """Return a ``flash_attention`` callable backed by ``torch_npu``.

        ``torch`` and ``torch_npu`` are imported here rather than at module
        level, so importing this module does not require them.
        """
        from typing import Optional

        import torch
        import torch_npu

        def flash_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            dropout_p: float = 0.0,
            scale: Optional[float] = None,
            attention_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
            cu_seqlens_q: Optional[torch.Tensor] = None,
            cu_seqlens_kv: Optional[torch.Tensor] = None,
            max_seqlen_q: Optional[int] = None,
            max_seqlen_kv: Optional[int] = None,
            q_indices: Optional[torch.Tensor] = None,
            kv_indices: Optional[torch.Tensor] = None,
        ):
            """Run ``torch_npu.npu_fusion_attention`` on ``q``/``k``/``v`` in "BNSD" layout.

            Args:
                q, k, v: Attention inputs. Dim 1 of ``q`` is used as the head
                    count and the tensors are passed with layout "BNSD".
                dropout_p: Dropout probability; forwarded as
                    ``keep_prob = 1 - dropout_p``.
                scale: Softmax scale. When ``None`` it defaults to
                    ``q.size(-1) ** -0.5``.
                attention_mask: Optional mask, converted with ``.bool()`` before
                    being handed to the kernel. When given it is used as is and
                    ``is_causal`` is not consulted.
                is_causal: Only used when ``attention_mask`` is ``None``; a
                    causal mask is then built from the sequence lengths.
                cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                q_indices, kv_indices: Accepted but not used by this
                    implementation.

            Returns:
                The first element of the tuple returned by
                ``torch_npu.npu_fusion_attention``.
            """
            if scale is None:
                # NOTE(review): torch_npu documents ``scale`` as a plain float
                # (default 1.0), so ``None`` cannot be forwarded. Fall back to
                # the usual 1/sqrt(head_dim) attention scaling - confirm.
                scale = q.size(-1) ** -0.5

            if attention_mask is not None:
                atten_mask = attention_mask.bool()
            elif is_causal:
                # NOTE(review): relies on torch_npu treating True as "masked
                # out" - confirm against the Ascend docs. The strict upper
                # triangle hides future positions (top-left aligned when the
                # query and key lengths differ).
                atten_mask = torch.triu(
                    torch.ones(q.size(2), k.size(2), dtype=torch.bool, device=q.device),
                    diagonal=1,
                )
            else:
                # No mask requested: previously this path raised AttributeError
                # on ``None.bool()``.
                atten_mask = None

            num_heads = q.size(1)
            return torch_npu.npu_fusion_attention(
                q,
                k,
                v,
                num_heads,
                "BNSD",
                atten_mask=atten_mask,
                scale=scale,
                keep_prob=1 - dropout_p,
            )[0]

        return flash_attention