[Device]Support npu (#6159)
* support npu
* support pretrain; fix
* support lora; fix
* support chatglm; fix
* Update train.py
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -1,3 +1,5 @@
+import math
+
 from ...base_extension import _Extension
@@ -47,6 +49,8 @@ class FlashAttentionNpuExtension(_Extension):
             q_indices: Optional[torch.Tensor] = None,
             kv_indices: Optional[torch.Tensor] = None,
         ):
+            if scale is None:
+                scale = 1.0 / math.sqrt(q.size(-1))
             num_heads = q.size(1)
             return torch_npu.npu_fusion_attention(
                 q,