diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py
index 2411b6482..8598cf0ae 100644
--- a/colossalai/kernel/kernel_loader.py
+++ b/colossalai/kernel/kernel_loader.py
@@ -118,6 +118,8 @@ class FlashAttentionLoader(KernelLoader):
         FlashAttentionSdpaCudaExtension,
     ]
 
+class FlashAttentionDaoLoader(KernelLoader):
+    REGISTRY = [FlashAttentionDaoCudaExtension]
 
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index bf4fa77c6..29ef64a4d 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -10,6 +10,7 @@ from einops import rearrange
 from colossalai.kernel.kernel_loader import (
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
+    FlashAttentionDaoLoader,
     FlashAttentionWithCustomMaskLoader,
     KernelLoader,
 )
@@ -17,6 +18,8 @@ from colossalai.logging import get_dist_logger
 
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
+MEMORY_BOUND = 10 * 1e9
+
 __all__ = [
     "AttnMaskType",
     "ColoAttention",
@@ -104,7 +107,7 @@ class ColoAttention:
             }
 
     @staticmethod
-    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size: int) -> Callable:
         ColoAttention._init_kernels_dispatch()
         if (
             dtype not in ColoAttention._kernel_dispatch_map
@@ -113,12 +116,16 @@ class ColoAttention:
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
            )
+
+        if size > MEMORY_BOUND:
+            FlashAttentionDaoLoader().load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-        return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+        return FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
     @staticmethod
     def prepare_attn_kwargs(
@@ -163,7 +170,7 @@ class ColoAttention:
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
             attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
             if s_q != 1:
-                attention_mask = attention_mask.tril(diagonal=0)
+                attention_mask.tril_(diagonal=0)
             attention_mask = attention_mask.expand(b, s_q, s_kv)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
@@ -197,6 +204,15 @@ class ColoAttention:
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
+
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = s_q * s_kv * element_size
+        if memory_size > MEMORY_BOUND:
+            attention_mask = torch.empty((0,), dtype=dtype, device=device)
+            outputs["attention_mask"] = attention_mask
+            if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
+                outputs["attention_mask_type"] = AttnMaskType.CAUSAL
+
         return outputs
 
     @staticmethod
@@ -278,8 +294,16 @@ class ColoAttention:
             assert attention_mask_type == AttnMaskType.CUSTOM
 
         # kernel dispatch
+        b, _, s_q, _ = q.shape
+        b, _, s_kv, _ = v.shape
+        element_size = torch.tensor([], dtype=q.dtype).element_size()
+        memory_size = s_q * s_kv * element_size
+        if memory_size > MEMORY_BOUND:
+            attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
+            assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or attention_mask_type == AttnMaskType.PADDED
+
         mask_type = attention_mask_type if attention_mask is not None else None
-        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
             AttnMaskType.CAUSAL,
             AttnMaskType.PADDED_CAUSAL,
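
Note (reviewer sketch, not part of the patch): the new dispatch path reduces to comparing the bytes a dense (s_q, s_kv) mask would occupy against MEMORY_BOUND; above the bound the dense mask is skipped and FlashAttentionDaoLoader is used, below it the existing per-dtype kernel map is untouched. The helper name estimate_mask_bytes below is hypothetical and only mirrors that check.

import torch

MEMORY_BOUND = 10 * 1e9  # same 10 GB threshold introduced in attn.py

def estimate_mask_bytes(s_q: int, s_kv: int, dtype: torch.dtype) -> int:
    # Bytes needed to materialize a full (s_q, s_kv) attention mask in `dtype`,
    # computed the same way as in the patch (element_size of an empty tensor).
    element_size = torch.tensor([], dtype=dtype).element_size()
    return s_q * s_kv * element_size

# A 131072 x 131072 bf16 mask is ~34 GB, so the Dao flash-attn kernel is dispatched
# and no dense mask is built; a 4096 x 4096 mask (~34 MB) keeps the default path.
print(estimate_mask_bytes(131072, 131072, torch.bfloat16) > MEMORY_BOUND)  # True
print(estimate_mask_bytes(4096, 4096, torch.bfloat16) > MEMORY_BOUND)      # False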