updated flash attention api

This commit is contained in:
zbian
2022-11-14 17:11:33 +08:00
committed by アマデウス
parent 36c0f3ea5b
commit 6877121377
3 changed files with 64 additions and 36 deletions

@@ -1,3 +1,3 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
-from .scaled_softmax import FusedScaleMaskSoftmax
from .multihead_attention import MultiHeadAttention
+from .scaled_softmax import FusedScaleMaskSoftmax

@@ -5,6 +5,7 @@ This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
"""
+import math
import os
import subprocess
@@ -36,17 +37,17 @@ except ImportError:
    print('please install triton from https://github.com/openai/triton')
    HAS_TRITON = False

try:
+    from flash_attn.flash_attention import FlashAttention
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_func,
        flash_attn_unpadded_kvpacked_func,
        flash_attn_unpadded_qkvpacked_func,
    )
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')

if HAS_TRITON:

    @triton.jit
@@ -409,6 +410,25 @@ if HAS_TRITON:
if HAS_FLASH_ATTN:
    from einops import rearrange

+    class MaskedFlashAttention(torch.nn.Module):
+
+        def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None:
+            super().__init__()
+            self.num_attention_heads = num_attention_heads
+            self.attention_head_size = attention_head_size
+            self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size),
+                                                 attention_dropout=attention_dropout)
+
+        def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False):
+            if attention_mask.dtype is not torch.bool:
+                attention_mask = attention_mask.bool()
+            qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads)
+            context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal)
+            context = rearrange(context, 'b s h d -> b s (h d)')
+            return context
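A minimal usage sketch of the new module (illustrative only, not part of this commit): MaskedFlashAttention consumes a packed qkv tensor of shape (batch, seqlen, 3 * num_heads * head_dim) and a key-padding mask of shape (batch, seqlen); the flash_attn kernels need half-precision tensors on a CUDA device, and the sizes below are made-up assumptions.

    attn = MaskedFlashAttention(num_attention_heads=8, attention_head_size=64, attention_dropout=0.0)
    qkv = torch.randn(2, 128, 3 * 8 * 64, dtype=torch.float16, device='cuda')
    mask = torch.ones(2, 128, dtype=torch.bool, device='cuda')    # True marks tokens to keep
    context = attn(qkv, mask)                                     # -> (2, 128, 8 * 64)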
    def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
        """
        Arguments:
@@ -423,15 +443,15 @@ if HAS_FLASH_ATTN:
            out: (total, nheads, headdim).
        """
        max_s = seq_len
-        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
-                                  device=qkv.device)
-        out = flash_attn_unpadded_qkvpacked_func(
-            qkv, cu_seqlens, max_s, dropout_p,
-            softmax_scale=sm_scale, causal=causal
-        )
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
+        out = flash_attn_unpadded_qkvpacked_func(qkv,
+                                                 cu_seqlens,
+                                                 max_s,
+                                                 dropout_p,
+                                                 softmax_scale=sm_scale,
+                                                 causal=causal)
        return out
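A hedged usage sketch (not part of the commit): flash_attention_qkv assumes a packed qkv tensor of shape (batch_size * seq_len, 3, nheads, headdim), i.e. every sequence in the batch has the same length, since cu_seqlens is built with torch.arange; the fp16 CUDA tensors and sizes below are assumptions.

    nheads, headdim, batch, seqlen = 8, 64, 2, 128
    qkv = torch.randn(batch * seqlen, 3, nheads, headdim, dtype=torch.float16, device='cuda')
    out = flash_attention_qkv(qkv, sm_scale=1.0 / math.sqrt(headdim), batch_size=batch, seq_len=seqlen)
    # out: (batch * seqlen, nheads, headdim)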
    def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
        """
        Arguments:
@@ -447,19 +467,14 @@ if HAS_FLASH_ATTN:
            out: (total, nheads, headdim).
        """
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
-        cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
-        out = flash_attn_unpadded_kvpacked_func(q,
-                                                kv,
-                                                cu_seqlens_q,
-                                                cu_seqlens_k,
-                                                q_seqlen,
-                                                kv_seqlen,
-                                                dropout_p,
-                                                sm_scale,
-                                                causal)
+        cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen,
+                                    step=kv_seqlen,
+                                    dtype=torch.int32,
+                                    device=kv.device)
+        out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
+                                                sm_scale, causal)
        return out
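A hedged usage sketch (not part of the commit): flash_attention_q_kv assumes q of shape (batch_size * q_seqlen, nheads, headdim) and a packed kv of shape (batch_size * kv_seqlen, 2, nheads, headdim), again with uniform sequence lengths; dtype, device and sizes below are assumptions.

    nheads, headdim, batch, q_seqlen, kv_seqlen = 8, 64, 2, 128, 256
    q = torch.randn(batch * q_seqlen, nheads, headdim, dtype=torch.float16, device='cuda')
    kv = torch.randn(batch * kv_seqlen, 2, nheads, headdim, dtype=torch.float16, device='cuda')
    out = flash_attention_q_kv(q, kv, 1.0 / math.sqrt(headdim), batch, q_seqlen, kv_seqlen)
    # out: (batch * q_seqlen, nheads, headdim)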
    def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
        """
        Arguments:
@@ -476,14 +491,9 @@ if HAS_FLASH_ATTN:
            out: (total, nheads, headdim).
        """
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
-        cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=k.device)
-        return flash_attn_unpadded_func(q,
-                                        k,
-                                        v,
-                                        cu_seqlens_q,
-                                        cu_seqlens_kv,
-                                        q_seqlen,
-                                        kv_seqlen,
-                                        dropout_p,
-                                        sm_scale,
-                                        causal)
+        cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen,
+                                     step=kv_seqlen,
+                                     dtype=torch.int32,
+                                     device=k.device)
+        return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
+                                        causal)
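A hedged usage sketch (not part of the commit): flash_attention_q_k_v takes separate q, k and v tensors of shape (batch_size * seqlen, nheads, headdim), with k and v sharing kv_seqlen; dtype, device and sizes below are assumptions.

    nheads, headdim, batch, q_seqlen, kv_seqlen = 8, 64, 2, 128, 256
    q = torch.randn(batch * q_seqlen, nheads, headdim, dtype=torch.float16, device='cuda')
    k = torch.randn(batch * kv_seqlen, nheads, headdim, dtype=torch.float16, device='cuda')
    v = torch.randn(batch * kv_seqlen, nheads, headdim, dtype=torch.float16, device='cuda')
    out = flash_attention_q_k_v(q, k, v, 1.0 / math.sqrt(headdim), batch, q_seqlen, kv_seqlen)
    # out: (batch * q_seqlen, nheads, headdim)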