mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests
This commit is contained in:
@@ -6,7 +6,7 @@ from .extensions import (
|
||||
CpuAdamX86Extension,
|
||||
FlashAttentionDaoCudaExtension,
|
||||
FlashAttentionNpuExtension,
|
||||
FlashAttentionXformersCudaExtension,
|
||||
FlashAttentionSdpaCudaExtension,
|
||||
FusedOptimizerCudaExtension,
|
||||
LayerNormCudaExtension,
|
||||
MoeCudaExtension,
|
||||
@@ -65,9 +65,9 @@ class KernelLoader:
|
||||
else:
|
||||
usable_exts = []
|
||||
for ext in exts:
|
||||
if ext.is_hardware_available():
|
||||
if ext.is_available():
|
||||
# make sure the machine is compatible during kernel loading
|
||||
ext.assert_hardware_compatible()
|
||||
ext.assert_compatible()
|
||||
usable_exts.append(ext)
|
||||
|
||||
assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
|
||||
@@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
|
||||
|
||||
|
||||
class FlashAttentionLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
|
||||
REGISTRY = [
|
||||
FlashAttentionNpuExtension,
|
||||
FlashAttentionDaoCudaExtension,
|
||||
FlashAttentionSdpaCudaExtension,
|
||||
]
|
||||
|
||||
|
||||
class FlashAttentionWithPaddingMaskLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
|
||||
|
||||
|
||||
class FlashAttentionWithCustomMaskLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
|
||||
|
||||
|
||||
class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionSdpaCudaExtension]
|
||||
|
Reference in New Issue
Block a user