Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[npu] support triangle attention for llama (#5130)
* update fused attn
* update spda
* tri attn
* update triangle
* import
* fix
* fix
@@ -280,3 +280,21 @@ def create_randomizer_with_offset(
     Randomizer.increment_index()
 
     return Randomizer(seed=base_seed)
+
+
+def get_attention_kernel():
+    """
+    Get the attention kernel based on the device type.
+    """
+    from colossalai.kernel.cuda_native import AttnMaskType
+
+    if torch.cuda.is_available():
+        from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
+    else:
+        try:
+            torch.npu.is_available()
+            from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
+        except:
+            raise Exception("No available device for attention kernel!")
+
+    return AttnMaskType, AttentionKernel
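The committed helper leans on a side effect: when Ascend's torch_npu package is not installed, the `torch.npu` attribute does not exist, so the bare call to `torch.npu.is_available()` raises `AttributeError` and the bare `except:` converts it into the "no available device" error. That bare `except:` also swallows unrelated failures, such as an `ImportError` from `colossalai.kernel.npu`. Below is a minimal sketch of the same CUDA/NPU dispatch with a narrower guard; it assumes the import paths shown in the diff and is not part of the commit itself.

import torch


def get_attention_kernel():
    """Pick an attention kernel based on the available device (sketch, not the committed code)."""
    # AttnMaskType is imported unconditionally, as in the diff above.
    from colossalai.kernel.cuda_native import AttnMaskType

    if torch.cuda.is_available():
        from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
    elif getattr(torch, "npu", None) is not None and torch.npu.is_available():
        # `torch.npu` only exists once torch_npu (Ascend) has been imported,
        # so probe it with getattr instead of catching every exception.
        from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
    else:
        raise RuntimeError("No available device (CUDA or NPU) for attention kernel!")

    return AttnMaskType, AttentionKernel

A caller unpacks both names, e.g. `AttnMaskType, AttentionKernel = get_attention_kernel()`, and then uses `AttentionKernel` identically on CUDA and NPU; returning the pair keeps call sites from importing the CUDA-only module directly.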