[Fix] resolve conflicts of merging main

Author: Yuanheng
Date: 2024-04-08 16:21:47 +08:00
451 changed files with 15350 additions and 10694 deletions


@@ -6,7 +6,7 @@ from .extensions import (
     CpuAdamX86Extension,
     FlashAttentionDaoCudaExtension,
     FlashAttentionNpuExtension,
-    FlashAttentionXformersCudaExtension,
+    FlashAttentionSdpaCudaExtension,
     FusedOptimizerCudaExtension,
     InferenceOpsCudaExtension,
     LayerNormCudaExtension,
@@ -67,9 +67,9 @@ class KernelLoader:
         else:
             usable_exts = []
             for ext in exts:
-                if ext.is_hardware_available():
+                if ext.is_available():
                     # make sure the machine is compatible during kernel loading
-                    ext.assert_hardware_compatible()
+                    ext.assert_compatible()
                     usable_exts.append(ext)
 
         assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
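
Note: the rename above changes the extension interface that KernelLoader filters on: extensions now expose is_available() and assert_compatible() instead of is_hardware_available() and assert_hardware_compatible(). A minimal sketch of that contract, using a hypothetical DummyExtension (only the two method names come from this diff):

    class DummyExtension:
        # Hypothetical extension; illustrates the renamed contract only.

        def is_available(self) -> bool:
            # Return True if the runtime dependency (e.g. a CUDA/NPU library) is usable.
            return True

        def assert_compatible(self) -> None:
            # Raise if the current machine cannot run this kernel.
            pass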
@@ -112,4 +112,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
 class FlashAttentionLoader(KernelLoader):
-    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
+    REGISTRY = [
+        FlashAttentionNpuExtension,
+        FlashAttentionDaoCudaExtension,
+        FlashAttentionSdpaCudaExtension,
+    ]
+
+
+class FlashAttentionWithPaddingMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
+
+
+class FlashAttentionWithCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
+
+
+class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionSdpaCudaExtension]
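
With the registries above, selecting a flash-attention kernel would look roughly like the following sketch; the module path and the load() call are assumptions based on the class names in this diff, not shown in it:

    # Hypothetical usage sketch: FlashAttentionLoader appears in this diff,
    # but the import path and load() signature are assumptions.
    from colossalai.kernel.kernel_loader import FlashAttentionLoader

    # Picks a registered extension that reports is_available() and passes
    # assert_compatible() on the current machine, then returns its kernel.
    flash_attention_impl = FlashAttentionLoader().load()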