[npu] support triangle attention for llama (#5130)

* update fused attn

* update spda

* tri attn

* update triangle

* import

* fix

* fix
Xuanlei Zhao
2023-11-30 14:21:30 +08:00
committed by GitHub
parent f4e72c9992
commit d6df19bae7
9 changed files with 264 additions and 3 deletions
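
The "triangle attention" in the title presumably refers to causal self-attention on Ascend NPUs, where a lower-triangular mask keeps each token from attending to later positions. The fused NPU kernel itself is not shown in the hunks below; the following is a minimal plain-PyTorch sketch of the masking pattern only, with the tensor shapes and the `triangle_attention` helper name chosen for illustration.

```python
# Minimal sketch (assumption: "triangle" means a lower-triangular causal mask).
# Plain PyTorch only; the actual PR dispatches to a fused NPU attention kernel.
import math

import torch


def triangle_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q, k, v: (batch, heads, seq_len, head_dim); returns the same shape."""
    seq_len, head_dim = q.size(-2), q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    # Lower-triangular mask: token i may only attend to tokens j <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


q = k = v = torch.randn(1, 8, 16, 64)
out = triangle_attention(q, k, v)  # -> torch.Size([1, 8, 16, 64])
```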


@@ -29,7 +29,6 @@ except ImportError:
    HAS_FLASH_ATTN = False
if HAS_FLASH_ATTN:
    pass
    from .utils import SeqLenInfo
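
The hunk above only trims the guarded flash-attn import block. For reference, a standalone sketch of that optional-dependency guard is below; the fallback messages are illustrative and not the module's exact contents.

```python
# Standalone sketch of the optional-dependency guard pattern shown above.
# If flash-attn is not installed, the flag stays False and callers are
# expected to fall back to another attention implementation.
try:
    from flash_attn import flash_attn_func  # noqa: F401

    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

if HAS_FLASH_ATTN:
    print("flash-attn kernels available")
else:
    print("flash-attn not installed; using a fallback attention implementation")
```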


@@ -44,6 +44,7 @@ class ColoAttention(torch.nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
+        origin_attn_mask: Optional[torch.Tensor] = None,
         attn_mask_type: Optional[AttnMaskType] = None,
         bias: Optional[torch.Tensor] = None,
     ):
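
The only signature change here is the new `origin_attn_mask` argument, which passes the raw mask through alongside the processed additive `attn_mask`. A hypothetical call site is sketched below; the constructor, the `AttnMaskType.causal` value, and the way the two masks are derived are assumptions for illustration, not taken from this diff.

```python
# Hypothetical caller of the extended forward() above. Everything except the
# `origin_attn_mask=` keyword itself is assumed for illustration.
import torch

batch, heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(batch, seq_len, heads, head_dim)
k = torch.randn(batch, seq_len, heads, head_dim)
v = torch.randn(batch, seq_len, heads, head_dim)

# Raw boolean causal mask, kept around for backends (e.g. an NPU triangle
# kernel) that want the mask before it is converted to an additive bias.
origin_attn_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# Additive bias derived from it for the standard softmax-attention paths.
attn_mask = torch.zeros(seq_len, seq_len).masked_fill(~origin_attn_mask, float("-inf"))

# attn = ColoAttention(embed_dim=heads * head_dim, num_heads=heads)  # assumed ctor
# out = attn(
#     q, k, v,
#     attn_mask=attn_mask,
#     origin_attn_mask=origin_attn_mask,   # new argument added by this PR
#     attn_mask_type=AttnMaskType.causal,  # assumed enum member
# )
```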