[NPU] support npu (#6089)

* support npu

* support pretrain

* support lora

* support chatglm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py
flybird11111
2024-11-20 15:28:35 +08:00
committed by GitHub
parent 4c8e85ee0d
commit 64f74a157e
18 changed files with 296 additions and 153 deletions

@@ -27,6 +27,7 @@ class FlashAttentionNpuExtension(_Extension):
         )
 
     def load(self):
+        import math
         from typing import Optional
 
         import torch
@@ -47,6 +48,8 @@ class FlashAttentionNpuExtension(_Extension):
             q_indices: Optional[torch.Tensor] = None,
             kv_indices: Optional[torch.Tensor] = None,
         ):
+            if scale is None:
+                scale = 1.0 / math.sqrt(q.size(-1))
             num_heads = q.size(1)
             return torch_npu.npu_fusion_attention(
                 q,
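
The effect of the second hunk: when the caller passes no scale, the NPU attention wrapper now falls back to 1/sqrt(head_dim), the standard scaled-dot-product-attention factor. Below is a minimal pure-PyTorch sketch of that default-scale semantics; the attention_reference helper and the (batch, num_heads, seq_len, head_dim) tensor shapes are illustrative assumptions, not code from this patch.

import math

import torch


def attention_reference(q, k, v, scale=None):
    # Hypothetical CPU reference, not the NPU kernel itself.
    # q, k, v assumed (batch, num_heads, seq_len, head_dim), consistent
    # with the patch reading the head count from q.size(1).
    # Mirror the patch: default scale to 1/sqrt(head_dim) when not given.
    if scale is None:
        scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    return torch.matmul(torch.softmax(scores, dim=-1), v)


q = k = v = torch.randn(2, 8, 128, 64)
out_default = attention_reference(q, k, v)  # implicit 1/sqrt(64)
out_explicit = attention_reference(q, k, v, scale=1.0 / math.sqrt(64))
assert torch.allclose(out_default, out_explicit)

Resolving the default inside the wrapper keeps callers that never supply scale numerically aligned with the usual softmax-attention convention, instead of silently running unscaled attention on the NPU path.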