diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py
index 2411b6482..36a49aae9 100644
--- a/colossalai/kernel/kernel_loader.py
+++ b/colossalai/kernel/kernel_loader.py
@@ -119,6 +119,10 @@ class FlashAttentionLoader(KernelLoader):
     ]
 
 
+class FlashAttentionDaoLoader(KernelLoader):
+    REGISTRY = [FlashAttentionDaoCudaExtension]
+
+
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
 
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index bf4fa77c6..2f8e4d677 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from colossalai.kernel.kernel_loader import (
+    FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
     FlashAttentionWithCustomMaskLoader,
@@ -17,6 +18,8 @@ from colossalai.logging import get_dist_logger
 
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
+MEMORY_BOUND = 10 * 1e9
+
 __all__ = [
     "AttnMaskType",
     "ColoAttention",
@@ -77,6 +80,7 @@ def get_pad_info(
 
 class ColoAttention:
     _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+    _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
 
     @staticmethod
     def _init_kernels_dispatch():
@@ -102,9 +106,11 @@ class ColoAttention:
                 torch.bfloat16: half_dispatch_map,
                 torch.float32: float_dispatch_map,
             }
+        if ColoAttention._flash_kernel_dispatch is None:
+            ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()
 
     @staticmethod
-    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
         ColoAttention._init_kernels_dispatch()
         if (
             dtype not in ColoAttention._kernel_dispatch_map
@@ -113,12 +119,19 @@ class ColoAttention:
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
+
+        if size >= MEMORY_BOUND:
+            ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-        return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+        if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
+            return ColoAttention._flash_kernel_dispatch
+        else:
+            return ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
     @staticmethod
     def prepare_attn_kwargs(
@@ -154,6 +167,8 @@ class ColoAttention:
             return {}
         assert len(shape_4d) == 4 and shape_4d[1] == 1
         b, _, s_q, s_kv = shape_4d
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         outputs = {}
         if (q_padding_mask is None or q_padding_mask.bool().all()) and (
             kv_padding_mask is None or kv_padding_mask.bool().all()
@@ -161,10 +176,13 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
-            if s_q != 1:
-                attention_mask = attention_mask.tril(diagonal=0)
-            attention_mask = attention_mask.expand(b, s_q, s_kv)
+            if memory_size < MEMORY_BOUND:
+                attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+                if s_q != 1:
+                    attention_mask.tril_(diagonal=0)
+                attention_mask = attention_mask.expand(b, s_q, s_kv)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
@@ -177,7 +195,6 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -190,10 +207,17 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if s_q != 1:
-                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if memory_size < MEMORY_BOUND:
+                    if s_q != 1:
+                        attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+                        attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                else:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
+                if memory_size < MEMORY_BOUND:
+                    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
@@ -278,8 +302,12 @@ class ColoAttention:
                 assert attention_mask_type == AttnMaskType.CUSTOM
 
         # kernel dispatch
+        b, _, s_q, _ = q.shape
+        b, _, s_kv, _ = v.shape
+        element_size = torch.tensor([], dtype=q.dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         mask_type = attention_mask_type if attention_mask is not None else None
-        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
             AttnMaskType.CAUSAL,
             AttnMaskType.PADDED_CAUSAL,
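
For reference, below is a minimal standalone sketch (not part of the patch) of the size check this diff introduces: a dense (s_q, s_kv) mask costs `s_q * s_kv * element_size` bytes, and once that crosses `MEMORY_BOUND` (~10 GB) the causal/padded-causal paths skip materializing the mask and dispatch to the Dao FlashAttention kernel instead. The helpers `mask_memory_size` and `would_use_dao_kernel` are hypothetical names used only for illustration; they are not ColossalAI APIs.

```python
import torch

# Same threshold as the MEMORY_BOUND constant added in attn.py (bytes).
MEMORY_BOUND = 10 * 1e9


def mask_memory_size(s_q: int, s_kv: int, dtype: torch.dtype) -> int:
    # Mirrors prepare_attn_kwargs: bytes required to materialize one (s_q, s_kv) mask.
    element_size = torch.tensor([], dtype=dtype).element_size()
    return s_q * s_kv * element_size


def would_use_dao_kernel(s_q: int, s_kv: int, dtype: torch.dtype, causal: bool) -> bool:
    # The dispatcher only reroutes CAUSAL / PADDED_CAUSAL cases to the Dao kernel.
    return causal and mask_memory_size(s_q, s_kv, dtype) >= MEMORY_BOUND


if __name__ == "__main__":
    # 4k x 4k fp16 mask: ~32 MB, stays on the regular kernel path.
    print(would_use_dao_kernel(4096, 4096, torch.float16, causal=True))      # False
    # 128k x 128k fp16 mask: ~34 GB, exceeds the bound, Dao kernel path is taken.
    print(would_use_dao_kernel(131072, 131072, torch.float16, causal=True))  # True
```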