mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
@@ -9,7 +9,7 @@ from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
|
||||
from colossalai.kernel.cuda_native.mha.mha import ColoAttention
|
||||
from colossalai.kernel.cuda_native import ColoAttention
|
||||
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
|
||||
|
||||
DTYPE = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
Reference in New Issue
Block a user