From 7bc0afc901f2f0ce187cab9a0b1587740094d7b5 Mon Sep 17 00:00:00 2001 From: zbian Date: Fri, 17 Mar 2023 15:09:47 +0800 Subject: [PATCH] updated flash attention usage --- LICENSE | 70 ++++++ .../kernel/cuda_native/flash_attention.py | 207 +++++++++++++----- tests/test_utils/test_flash_attention.py | 190 +++++++--------- 3 files changed, 302 insertions(+), 165 deletions(-) diff --git a/LICENSE b/LICENSE index 394791da2..c7a5bb168 100644 --- a/LICENSE +++ b/LICENSE @@ -326,3 +326,73 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR Flash Attention ---------------- + + BSD 3-Clause License + + Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR Facebook xFormers ---------------- + + From xFormers: + + Copyright (c) Facebook, Inc. and its affiliates + + + === + + BSD 3-Clause License + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py index 907fa640d..d793815ed 100644 --- a/colossalai/kernel/cuda_native/flash_attention.py +++ b/colossalai/kernel/cuda_native/flash_attention.py @@ -1,12 +1,6 @@ """ -The triton-based flash attention implementation is copied from the OpenAI/triton repository - -You can find the repository in Triton https://github.com/openai/triton -You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py - -Reference: -1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf -2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf +A general attention module using the flash attention kernels from xformers: +https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha """ import math @@ -15,6 +9,159 @@ import subprocess import torch +try: + from xformers.ops.fmha import memory_efficient_attention + HAS_MEM_EFF_ATTN = True +except ImportError: + HAS_MEM_EFF_ATTN = False + print('please install xformers from https://github.com/facebookresearch/xformers') + +if HAS_MEM_EFF_ATTN: + + from typing import Optional + + from einops import rearrange + from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp + from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias + + from .scaled_softmax import AttnMaskType + + allow_alibi = True + for op in MemoryEfficientAttentionCutlassOp: + allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) + + class Unpad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): + ctx.save_for_backward(indices) + # [b, s, ...] + assert tensor.ndim >= 3 + ctx.bsz = tensor.shape[0] + out = rearrange(tensor, 'b s ... -> (b s) ...') + ctx.shape = out.shape + # [1, ntokens, ...] + return out[indices].unsqueeze(0) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + # [b*s, ...] + grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) + grad[indices] = grad_output.squeeze(0) + grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz) + # [b, s, ...] + return grad, None + + class Repad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): + ctx.save_for_backward(indices) + # [ntokens, ...] + tensor = tensor.squeeze(0) + out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + # [b*s, ...] + out[indices] = tensor + # [b, s, ...] + out = rearrange(out, '(b s) ... -> b s ...', b=batch_size) + return out + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + # [b*s, ...] + grad_output = rearrange(grad_output, 'b s ... -> (b s) ...') + grad = grad_output[indices] + # [1, ntokens, ...] + return grad.unsqueeze(0), None, None, None + + class ColoAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0): + super().__init__() + assert embed_dim % num_heads == 0, \ + f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + self.scale = 1 / math.sqrt(embed_dim // num_heads) + self.dropout = dropout + + @staticmethod + def get_seq_info_from_mask(attn_mask: torch.Tensor): + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten() + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist() + return indices, seqlens + + @staticmethod + def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return Unpad.apply(tensor, indices) + + @staticmethod + def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: + return Repad.apply(tensor, indices, batch_size, seq_len) + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None): + batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] + attn_bias = None + if attn_mask_type == AttnMaskType.padding: # bert style + assert attn_mask is not None, \ + f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, \ + "attention mask is supposed to have shape (batch_size, seq_len), " + \ + f"but got {attn_mask.dim()} dimensions." + if tgt_len == src_len: + q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask) + kv_seqlen = None + if batch_size > 1: + query, key, value = self.unpad(torch.stack([query, key, value], dim=2), q_indices).unbind(dim=2) + else: + q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device) + q_seqlen = torch.LongTensor([tgt_len] * batch_size, device=query.device) + kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask) + if batch_size > 1: + query = rearrange(query, "b s ... -> c (b s) ...", c=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2) + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + elif attn_mask_type == AttnMaskType.causal: # gpt style + attn_bias = LowerTriangularMask() + + if bias is not None: # alibi / relative position emebedding + assert allow_alibi, "flash attention with bias is not supported in this system." + assert attn_mask_type == AttnMaskType.causal, \ + "attention with bias is only supported for causal attention so far." + attn_bias = attn_bias.add_bias(bias) + + out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale) + + if attn_mask_type == AttnMaskType.padding and batch_size > 1: + out = self.repad(out, q_indices, batch_size, tgt_len) + + out = rearrange(out, 'b s h d -> b s (h d)') + return out + + +########################################################################## +# the flash attention functions below that are copied +# from the OpenAI/triton repository will be deprecated +# You can find the repository in Triton https://github.com/openai/triton +# You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py +# Reference: +# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf +# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf + def triton_cuda_check(): cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda") @@ -52,13 +199,6 @@ except ImportError: HAS_FLASH_ATTN = False print('please install flash_attn from https://github.com/HazyResearch/flash-attention') -try: - from xformers.ops.fmha import memory_efficient_attention - HAS_MEM_EFF_ATTN = True -except ImportError: - HAS_MEM_EFF_ATTN = False - print('please install xformers from https://github.com/facebookresearch/xformers') - if HAS_TRITON: # the following functions are adapted from the OpenAI Triton tutorial # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py @@ -422,25 +562,6 @@ if HAS_TRITON: if HAS_FLASH_ATTN: - from einops import rearrange - - class MaskedFlashAttention(torch.nn.Module): - - def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None: - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_size = attention_head_size - self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size), - attention_dropout=attention_dropout) - - def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False): - if attention_mask.dtype is not torch.bool: - attention_mask = attention_mask.bool() - qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads) - context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal) - context = rearrange(context, 'b s h d -> b s (h d)') - return context - def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False): """ Arguments: @@ -511,20 +632,4 @@ if HAS_FLASH_ATTN: causal) -if HAS_MEM_EFF_ATTN: - - from einops import rearrange - from xformers.ops.fmha import LowerTriangularMask - - class MemoryEfficientAttention(torch.nn.Module): - - def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0): - super().__init__() - attention_head_size = hidden_size // num_attention_heads - self.scale = 1 / attention_head_size**0.5 - self.dropout = attention_dropout - - def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor): - context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale) - context = rearrange(context, 'b s h d -> b s (h d)') - return context +########################################################################## diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 58e3b21d9..441cbbb22 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -1,22 +1,13 @@ +import random + import pytest import torch from einops import rearrange -from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON - -if HAS_FLASH_ATTN: - from colossalai.kernel.cuda_native.flash_attention import ( - MaskedFlashAttention, - flash_attention_q_k_v, - flash_attention_q_kv, - flash_attention_qkv, - ) - -if HAS_TRITON: - from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention +from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN if HAS_MEM_EFF_ATTN: - from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): @@ -30,117 +21,88 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): return ref_out -@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available") -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)]) -def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): - torch.manual_seed(20) - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - sm_scale = 0.3 - dout = torch.randn_like(q) +@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): + D = H * D_HEAD - ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None + c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") + attn = ColoAttention(D, H, dropout=0.1) - # triton implementation - tri_out = triton_flash_attention(q, k, v, sm_scale) - tri_out.backward(dout) - tri_dv, v.grad = v.grad.clone(), None - tri_dk, k.grad = k.grad.clone(), None - tri_dq, q.grad = q.grad.clone(), None - # compare - assert torch.allclose(ref_out, tri_out, atol=1e-3) - assert torch.allclose(ref_dv, tri_dv, atol=1e-3) - assert torch.allclose(ref_dk, tri_dk, atol=1e-3) - assert torch.allclose(ref_dq, tri_dq, atol=1e-3) + x = torch.randn((B, S, D), dtype=dtype, device="cuda") + qkv = c_attn(x) + q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H) + y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) -@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available") -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)]) -def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): - torch.manual_seed(20) - q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - sm_scale = 0.3 - dout = torch.randn_like(q) + assert list(y.shape) == [B, S, D] - # reference implementation - ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None - - # flash implementation - q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v]) - dout = rearrange(dout, 'z h n d -> (z n) h d').detach() - for i in range(3): - if i == 0: - tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True) - elif i == 1: - kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1) - tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True) - else: - qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1) - tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True) - - tri_out.backward(dout, retain_graph=True) - - if i == 0: - tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout) - tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), - (tri_out, tri_dq, tri_dk, tri_dv)) - elif i == 1: - tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout) - tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1) - tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), - (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1))) - else: - tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout) - tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1) - tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), - (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1))) - - # compare - assert torch.allclose(ref_out, tri_out, atol=1e-3) - assert torch.allclose(ref_dv, tri_dv, atol=1e-3) - assert torch.allclose(ref_dk, tri_dk, atol=1e-3) - assert torch.allclose(ref_dq, tri_dq, atol=1e-3) - - -@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available") -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)]) -def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): - attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1) - - qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - attention_mask = torch.randint(2, (Z, H)).cuda().bool() - - out = attn(qkv, attention_mask) - - dout = torch.rand_like(out) - out.backward(dout) + dy = torch.rand_like(y) + y.backward(dy) @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)]) -def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16): - attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1) +@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): + D = H * D_HEAD - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") + attn = ColoAttention(D, H, dropout=0.1) - out = attn(q, k, v, attention_mask=LowerTriangularMask()) + x = torch.randn((B, S, D), dtype=dtype, device="cuda") + # attention mask of shape [B, S] with zero padding to max length S + mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)] + mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True) - dout = torch.rand_like(out) - out.backward(dout) + qkv = c_attn(x) + q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2) + y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding) + + assert list(y.shape) == [B, S, D] + + dy = torch.rand_like(y) + y.backward(dy) -if __name__ == '__main__': - test_flash_attention(3, 4, 2, 16) +@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): + D = H * D_HEAD + + c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") + attn = ColoAttention(D, H, dropout=0.1) + + x = torch.randn((B, S, D), dtype=dtype, device="cuda") + qkv = c_attn(x) + q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2) + y = attn(q, k, v) + + assert list(y.shape) == [B, S, D] + + dy = torch.rand_like(y) + y.backward(dy) + + +@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)]) +def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16): + D = H * D_HEAD + + q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda") + kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda") + + attn = ColoAttention(D, H, dropout=0.1) + + src = torch.randn((B, S, D), dtype=dtype, device="cuda") + tgt = torch.randn((B, T, D), dtype=dtype, device="cuda") + + q = q_attn(tgt) + kv = kv_attn(src) + q = rearrange(q, 'b s (h d) -> b s h d', h=H) + k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2) + y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) + + assert list(y.shape) == [B, T, D] + + dy = torch.rand_like(y) + y.backward(dy)