From 7a3dfd0c645fba51a02eb3c6ac88b4f09160ea7d Mon Sep 17 00:00:00 2001
From: flybird1111 <1829166702@qq.com>
Date: Wed, 9 Aug 2023 14:32:19 +0800
Subject: [PATCH] [shardformer] update shardformer to use flash attention 2
 (#4392)

* cherry-pick flash attention 2

cherry-pick flash attention 2

* [shardformer] update shardformer to use flash attention 2

[shardformer] update shardformer to use flash attention 2, fix

[shardformer] update shardformer to use flash attention 2, fix

[shardformer] update shardformer to use flash attention 2, fix
---
 colossalai/kernel/cuda_native/__init__.py  | 5 +++--
 colossalai/shardformer/modeling/blip2.py   | 2 +-
 colossalai/shardformer/modeling/chatglm.py | 3 +--
 colossalai/shardformer/modeling/gpt2.py    | 2 +-
 colossalai/shardformer/modeling/llama.py   | 2 +-
 colossalai/shardformer/modeling/opt.py     | 2 +-
 colossalai/shardformer/modeling/vit.py     | 2 +-
 colossalai/shardformer/modeling/whisper.py | 2 +-
 tests/test_utils/test_flash_attention.py   | 1 -
 9 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index 4910717b5..e0136d86e 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -1,8 +1,9 @@
 from .layer_norm import MixedFusedLayerNorm as LayerNorm
 from .mha.mha import ColoAttention
 from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
 
 __all__ = [
-    'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention'
+    'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention',
+    'AttnMaskType'
 ]
diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py
index c5c6b14ba..69730fd3d 100644
--- a/colossalai/shardformer/modeling/blip2.py
+++ b/colossalai/shardformer/modeling/blip2.py
@@ -65,7 +65,7 @@ def get_blip2_flash_attention_forward():
 
     from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
 
-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import ColoAttention
 
     def forward(
         self: Blip2Attention,
diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py
index 3d453c3bd..a95966c3b 100644
--- a/colossalai/shardformer/modeling/chatglm.py
+++ b/colossalai/shardformer/modeling/chatglm.py
@@ -19,7 +19,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
 
 
 def get_flash_core_attention_forward():
-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 
     from .chatglm2_6b.modeling_chatglm import CoreAttention
 
@@ -126,7 +126,6 @@ def get_jit_fused_glm_block_forward():
     return forward
 
 
-
 class ChatGLMPipelineForwards:
     '''
     This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index e02581fba..a12a9796f 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -674,7 +674,7 @@ def get_gpt2_flash_attention_forward():
 
     from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
 
-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 
     def split_heads(tensor, num_heads, attn_head_size):
         """
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 9d6335503..2f54daac5 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -392,7 +392,7 @@ def get_llama_flash_attention_forward():
 
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 
     def forward(
         self: LlamaAttention,
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index 299dfb556..bdf141816 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -8,7 +8,7 @@ def get_opt_flash_attention_forward():
 
     from transformers.models.opt.modeling_opt import OPTAttention
 
-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 
     def forward(
         self: OPTAttention,
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index 22c4dd998..eb0ea4c75 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -342,7 +342,7 @@ def get_vit_flash_self_attention_forward():
 
     from transformers.models.vit.modeling_vit import ViTSelfAttention
 
-    from colossalai.kernel.cuda_native.flash_attention import ColoAttention
+    from colossalai.kernel.cuda_native import ColoAttention
 
     def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index 6bc387ac8..0a16c6f78 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -8,7 +8,7 @@ def get_whisper_flash_attention_forward():
 
     from transformers.models.whisper.modeling_whisper import WhisperAttention
 
-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 
     def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
         return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 28369d4c9..f775710c4 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -13,7 +13,6 @@ if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
 
 DTYPE = [torch.float16, torch.bfloat16, torch.float32]
-FLASH_DTYPE = [torch.float16, torch.bfloat16]
 
 
 def attention_ref(q, k, v, attn_mask=None, causal=False):
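
Note on the import change: after this patch, AttnMaskType and ColoAttention are imported from the colossalai.kernel.cuda_native package root (re-exported in __init__.py above) instead of the removed colossalai.kernel.cuda_native.flash_attention module. The sketch below only illustrates the new import path; the ColoAttention constructor and call signature shown (embed_dim/num_heads arguments, (bsz, seq_len, num_heads, head_dim) inputs, attn_mask_type keyword) are assumptions modeled on the shardformer call sites of this era, not a verified API reference.

    # Minimal usage sketch, NOT the canonical API. Assumptions: ColoAttention(embed_dim, num_heads)
    # constructor, forward accepting (q, k, v) shaped (bsz, seq_len, num_heads, head_dim) plus an
    # attn_mask_type keyword, and AttnMaskType.causal in the enum; verify against your ColossalAI version.
    import torch

    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention  # new import path

    bsz, seq_len, num_heads, head_dim = 2, 128, 8, 64
    q = torch.randn(bsz, seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda')
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    attention = ColoAttention(embed_dim=num_heads * head_dim, num_heads=num_heads)
    out = attention(q, k, v, attn_mask_type=AttnMaskType.causal)  # uses the flash-attention path when available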