[Device] Support npu (#6159)

* support npu

* support pretrain

* support lora

* support chatglm

* Update train.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Author: flybird11111
Date: 2024-12-17 15:42:39 +08:00
Committed by: GitHub
Parent: e994c64568
Commit: aaafb38851
18 changed files with 295 additions and 152 deletions
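
Per the title and sub-commits, the goal is to let the pretraining and LoRA scripts run on Ascend NPU as well as CUDA. A minimal sketch of the device-resolution idea, assuming the torch_npu plugin is installed on Ascend hosts (resolve_device is a hypothetical helper, not code from this commit):

import torch


def resolve_device() -> torch.device:
    """Prefer Ascend NPU, then CUDA, then CPU (hypothetical helper)."""
    try:
        # torch_npu is Huawei's Ascend plugin; importing it registers the
        # "npu" backend and the torch.npu namespace (assumed installed).
        import torch_npu  # noqa: F401

        if torch.npu.is_available():
            return torch.device("npu")
    except ImportError:
        pass
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

ColossalAI also ships an accelerator abstraction (colossalai.accelerator.get_accelerator()) that plays a similar backend-dispatch role.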

colossalai/shardformer/policies/chatglm2.py

@@ -11,6 +11,7 @@ from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
 from ..modeling.chatglm2 import (
     get_chatglm_sequence_parallel_attention_forward,
     get_chatglm_sequence_parallel_forward_fn,
+    get_flash_attention_forward_for_chat_glm_model,
     get_flash_core_attention_forward,
     get_jit_fused_glm_block_forward,
 )
@@ -203,6 +204,13 @@ class ChatGLMPolicy(Policy):
                 policy=policy,
                 target_key="CoreAttention",
             )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_flash_attention_forward_for_chat_glm_model(),
+                },
+                policy=policy,
+                target_key="ChatGLMModel",
+            )
         # use sequence parallel
         if self.shard_config.enable_sequence_parallelism:
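
The hunk registers a second method replacement alongside the existing CoreAttention one, swapping ChatGLMModel's forward for the flash-attention variant. As a rough illustration of what such a per-module forward replacement boils down to (ToyModel and get_optimized_forward are illustrative names, not part of the commit):

import types

import torch
import torch.nn as nn


def get_optimized_forward():
    # Factory returning a replacement forward, mirroring the get_*_forward()
    # factories used by the policy; a real one would call a fused kernel.
    def forward(self, x):
        return self.linear(x)

    return forward


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


model = ToyModel()
# Rebind forward on the instance, which is essentially what
# append_or_create_method_replacement arranges for each target_key.
model.forward = types.MethodType(get_optimized_forward(), model)
out = model(torch.randn(2, 8))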