diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py
index 8598cf0ae..36a49aae9 100644
--- a/colossalai/kernel/kernel_loader.py
+++ b/colossalai/kernel/kernel_loader.py
@@ -118,9 +118,11 @@ class FlashAttentionLoader(KernelLoader):
         FlashAttentionSdpaCudaExtension,
     ]
 
+
 class FlashAttentionDaoLoader(KernelLoader):
     REGISTRY = [FlashAttentionDaoCudaExtension]
 
+
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
 
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 29ef64a4d..0a4f98535 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -8,9 +8,9 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from colossalai.kernel.kernel_loader import (
+    FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
-    FlashAttentionDaoLoader,
     FlashAttentionWithCustomMaskLoader,
     KernelLoader,
 )
@@ -116,7 +116,7 @@ class ColoAttention:
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
-
+
         if size > MEMORY_BOUND:
             FlashAttentionDaoLoader().load()
         # lazy load
@@ -124,8 +124,10 @@ class ColoAttention:
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-
-        return FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+        return (
+            FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
+        )
 
     @staticmethod
     def prepare_attn_kwargs(
@@ -204,15 +206,15 @@ class ColoAttention:
             if invert:
                 attention_mask = invert_mask(attention_mask).unsqueeze(1)
             outputs["attention_mask"] = attention_mask
-
+
         element_size = torch.tensor([], dtype=dtype).element_size()
-        memory_size = (s_q * s_kv * element_size)
+        memory_size = s_q * s_kv * element_size
         if memory_size > MEMORY_BOUND:
             attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs["attention_mask"] = attention_mask
             if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
                 outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-
+
         return outputs
 
     @staticmethod
@@ -297,11 +299,11 @@ class ColoAttention:
         b, _, s_q, _ = q.shape
         b, _, s_kv, _ = v.shape
         element_size = torch.tensor([], dtype=q.dtype).element_size()
-        memory_size = (s_q * s_kv * element_size)
+        memory_size = s_q * s_kv * element_size
         if memory_size > MEMORY_BOUND:
             attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
             assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
-
+
         mask_type = attention_mask_type if attention_mask is not None else None
         attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
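
For context, a minimal standalone sketch of the size check this diff introduces: when a dense (s_q, s_kv) mask of the given dtype would exceed MEMORY_BOUND bytes, the mask is left as an empty tensor and dispatch falls through to the Dao FlashAttention kernel (causal masking handled inside the kernel); otherwise the existing per-dtype/mask-type kernel map is used. The MEMORY_BOUND value below is an assumed placeholder for illustration only; the real constant lives in attn.py and is not shown in this diff.

import torch

# Assumed threshold, for illustration only (not the value defined in attn.py).
MEMORY_BOUND = 10 * 1024**3  # 10 GiB

def dense_mask_bytes(s_q: int, s_kv: int, dtype: torch.dtype) -> int:
    # Same estimate as the changed lines: bytes needed to materialize an
    # (s_q, s_kv) attention mask of `dtype`.
    element_size = torch.tensor([], dtype=dtype).element_size()
    return s_q * s_kv * element_size

# A 128k x 128k fp16 mask would take ~32 GiB, so the dense mask is skipped and
# the Dao FlashAttention path is taken; a 2k x 2k mask (~8 MiB) keeps the
# existing kernel dispatch.
print(dense_mask_bytes(131072, 131072, torch.float16) > MEMORY_BOUND)  # True
print(dense_mask_bytes(2048, 2048, torch.float16) > MEMORY_BOUND)      # False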