diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index a2ea761bf..c755ffa2f 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -118,7 +118,7 @@ class ColoAttention:
             )
 
         if size >= MEMORY_BOUND:
-            FlashAttentionDaoLoader().load()
+            flash_kernel = FlashAttentionDaoLoader().load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
@@ -126,7 +126,7 @@ class ColoAttention:
             ].load()
 
         if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
-            return FlashAttentionDaoLoader()
+            return flash_kernel
         else:
             return ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
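
For context on the fix: as the second hunk's use of `].load()` shows, a `KernelLoader`'s `load()` returns the concrete kernel, so the old code both discarded the Dao flash-attention kernel it had just loaded and returned a fresh, unloaded `FlashAttentionDaoLoader()` instance. Below is a minimal standalone sketch of the corrected pattern; the names `DemoKernelLoader`, `dispatch_kernel`, and the `MEMORY_BOUND` value are illustrative stand-ins, not colossalai APIs.

    # Standalone sketch of the pattern this patch fixes: bind the result of
    # load() and return the loaded kernel, not a fresh loader instance.
    from typing import Callable, Optional

    MEMORY_BOUND = 10 * 1024**3  # hypothetical threshold in bytes


    class DemoKernelLoader:
        """Stand-in for a kernel loader whose load() yields the concrete kernel."""

        def load(self) -> Callable[..., str]:
            # colossalai resolves a compiled extension here; we return a placeholder.
            return lambda *args, **kwargs: "flash-attn output"


    def dispatch_kernel(size: int) -> Optional[Callable[..., str]]:
        if size >= MEMORY_BOUND:
            flash_kernel = DemoKernelLoader().load()  # keep the loaded kernel
            return flash_kernel  # before the fix, an unloaded loader was returned here
        return None


    kernel = dispatch_kernel(size=16 * 1024**3)
    assert callable(kernel) and kernel() == "flash-attn output"

The same caching idea appears in the unchanged context lines: once a `KernelLoader` entry in `_kernel_dispatch_map` is loaded, it is replaced in the map by the loaded kernel so later dispatches skip the load step.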