[npu] support triangle attention for llama (#5130)

* update fused attn

* update sdpa

* tri attn

* update triangle

* import

* fix

* fix
Author: Xuanlei Zhao
Committed: 2023-11-30 14:21:30 +08:00 (via GitHub)
Parent: f4e72c9992
Commit: d6df19bae7
9 changed files with 264 additions and 3 deletions


@@ -280,3 +280,21 @@ def create_randomizer_with_offset(
Randomizer.increment_index()
return Randomizer(seed=base_seed)


def get_attention_kernel():
    """
    Get the attention kernel based on the device type.
    """
    from colossalai.kernel.cuda_native import AttnMaskType

    if torch.cuda.is_available():
        # On CUDA devices, use the fused CUDA-native attention kernel.
        from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
    else:
        try:
            # torch.npu only exists when the Ascend NPU extension is installed;
            # touching it raises if no NPU backend is present.
            torch.npu.is_available()
            from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
        except Exception:
            raise Exception("No available device for attention kernel!")

    return AttnMaskType, AttentionKernel
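
For context, a minimal usage sketch of the helper above, as it might be called from llama modeling code. This is not taken from the commit: the import path, constructor arguments, forward signature, and tensor shapes are assumptions (the diff view does not show the file name or a call site). It only illustrates the intent of the commit title: request causal ("triangle") masking via `AttnMaskType.causal` and let the returned kernel class dispatch to the CUDA or NPU implementation.

```python
import torch

# Path is an assumption; the file name is not visible in this diff view.
from colossalai.shardformer.layer.utils import get_attention_kernel

AttnMaskType, AttentionKernel = get_attention_kernel()

# Illustrative sizes only: 2 sequences of length 128, 32 heads of dim 128.
batch, seq_len, num_heads, head_dim = 2, 128, 32, 128
q = torch.randn(batch, seq_len, num_heads, head_dim)
k = torch.randn(batch, seq_len, num_heads, head_dim)
v = torch.randn(batch, seq_len, num_heads, head_dim)

# Constructor/forward signatures assumed to mirror the CUDA ColoAttention API.
attention = AttentionKernel(embed_dim=num_heads * head_dim, num_heads=num_heads)

# Causal ("triangle") attention: each token attends only to earlier positions.
out = attention(q, k, v, attn_mask_type=AttnMaskType.causal)
```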