[Device]Support npu (#6159)

* support npu

* support pretrain

support pretrain

fix

* support lora

fix

fix

* support chatglm

fix

fxi

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

fix

* Update train.py

* Update train.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-17 15:42:39 +08:00
committed by GitHub
parent e994c64568
commit aaafb38851
18 changed files with 295 additions and 152 deletions

View File

@@ -38,7 +38,7 @@ criterion = lambda x: x.loss
def move_to_cuda(batch):
return {k: v.cuda() for k, v in batch.items()}
return {k: v.to(get_accelerator().get_current_device()) for k, v in batch.items()}
@torch.no_grad()
@@ -266,7 +266,8 @@ def main():
cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
if model_name == "bert-base-uncased":
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
model = model.to(get_accelerator().get_current_device())
elif model_name == "albert-xxlarge-v2":
model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
else: