[Device]Support npu (#6159)

* support npu

* support pretrain

support pretrain

fix

* support lora

fix

fix

* support chatglm

fix

fxi

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

fix

* Update train.py

* Update train.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-17 15:42:39 +08:00
committed by GitHub
parent e994c64568
commit aaafb38851
18 changed files with 295 additions and 152 deletions

View File

@@ -509,9 +509,9 @@ class LazyInitContext:
# factory_like functions (eg. torch.empty_like())
def wrapper(*args, **kwargs):
orig_t = args[0]
return self.tensor_cls(
orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
)
device = kwargs.pop("device", orig_t.device)
dtype = kwargs.pop("dtype", orig_t.dtype)
return self.tensor_cls(orig_target, *orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs)
return wrapper, target