Mirror of https://github.com/hpcaitech/ColossalAI.git
updated attention kernel (#2133)
@@ -48,6 +48,13 @@ except ImportError:
     HAS_FLASH_ATTN = False
     print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
 
+try:
+    from xformers.ops.fmha import memory_efficient_attention
+    HAS_MEM_EFF_ATTN = True
+except ImportError:
+    HAS_MEM_EFF_ATTN = False
+    print('please install xformers from https://github.com/facebookresearch/xformers')
+
 if HAS_TRITON:
 
     @triton.jit
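The added try/except mirrors the flash-attention guard above it: xformers is an optional dependency, and HAS_MEM_EFF_ATTN records whether memory_efficient_attention could be imported, so the module still loads when it is absent. A hedged sketch of how downstream code might consume that flag (build_attention is a hypothetical helper, not part of this commit; ColossalAI's real call sites may differ):

import torch

def build_attention(hidden_size: int, num_heads: int, dropout: float = 0.0) -> torch.nn.Module:
    # Return the xformers-backed module only when the optional import succeeded;
    # MemoryEfficientAttention is only defined under `if HAS_MEM_EFF_ATTN:` below.
    if HAS_MEM_EFF_ATTN:
        return MemoryEfficientAttention(hidden_size, num_heads, attention_dropout=dropout)
    raise ImportError('xformers is required for memory-efficient attention; '
                      'install it from https://github.com/facebookresearch/xformers')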
@@ -497,3 +504,22 @@ if HAS_FLASH_ATTN:
                                              device=k.device)
         return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
                                         causal)
+
+
+if HAS_MEM_EFF_ATTN:
+
+    from einops import rearrange
+    from xformers.ops.fmha import LowerTriangularMask
+
+    class MemoryEfficientAttention(torch.nn.Module):
+
+        def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0):
+            super().__init__()
+            attention_head_size = hidden_size // num_attention_heads
+            self.scale = 1 / attention_head_size**0.5
+            self.dropout = attention_dropout
+
+        def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor):
+            context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale)
+            context = rearrange(context, 'b s h d -> b s (h d)')
+            return context
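For context, a hedged usage sketch of the new module (not part of the commit): xformers' memory_efficient_attention expects query/key/value shaped (batch, seq_len, num_heads, head_dim), and the LowerTriangularMask imported above can be passed as the attention_mask argument to get causal attention. The device, dtype, and sizes below are arbitrary assumptions.

# Hedged usage sketch, assuming a CUDA device and half precision as is typical
# for xformers kernels; the concrete sizes are arbitrary.
import torch
from xformers.ops.fmha import LowerTriangularMask

hidden_size, num_heads = 1024, 16
head_dim = hidden_size // num_heads          # 64
attn = MemoryEfficientAttention(hidden_size, num_heads, attention_dropout=0.1)

# (batch, seq_len, num_heads, head_dim) layout expected by memory_efficient_attention
q = torch.randn(2, 128, num_heads, head_dim, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn(q, k, v, LowerTriangularMask())   # causal mask; pass None for full attention
print(out.shape)                             # torch.Size([2, 128, 1024]) after the rearrange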