ColossalAI/extensions/flash_attention/__init__.py
Frank Lee, commit 7cfed5f076: [feat] refactored extension module (#5298), 2024-01-25 17:01:48 +08:00


from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension
from .flash_attention_npu import FlashAttentionNpuExtension
from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension

# Detect optional backends. A broad `except Exception` (rather than a bare
# `except:`) is used because these imports can fail with errors other than
# ImportError, e.g. when a prebuilt binary does not match the local CUDA/torch
# setup, and a bare except would also swallow KeyboardInterrupt/SystemExit.
try:
    import flash_attention  # noqa
    HAS_FLASH_ATTN = True
except Exception:
    HAS_FLASH_ATTN = False

try:
    import xformers  # noqa
    HAS_MEM_EFF_ATTN = True
except Exception:
    HAS_MEM_EFF_ATTN = False

__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"]
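
For reference, a minimal sketch of how downstream code might consume these exports, preferring the Dao CUDA kernels and falling back to xformers. The helper name, the preference order, and the import path are assumptions for illustration, not ColossalAI's actual loader logic:

# Illustrative sketch only; the import path below assumes this package is
# importable as `extensions.flash_attention` (adjust to your installation).
from extensions.flash_attention import (
    HAS_FLASH_ATTN,
    HAS_MEM_EFF_ATTN,
    FlashAttentionDaoCudaExtension,
    FlashAttentionXformersCudaExtension,
)


def pick_flash_attention_extension():
    """Return the preferred extension class, or None if no backend is installed."""
    if HAS_FLASH_ATTN:
        # flash-attention package found: prefer the Dao et al. CUDA kernels.
        return FlashAttentionDaoCudaExtension
    if HAS_MEM_EFF_ATTN:
        # Fall back to xformers' memory-efficient attention.
        return FlashAttentionXformersCudaExtension
    return None


if __name__ == "__main__":
    ext_cls = pick_flash_attention_extension()
    print("selected backend:", ext_cls.__name__ if ext_cls else "none available")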