Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-25 03:31:56 +00:00)
[Device] Support npu (#6159)
* support npu
* support pretrain
* support lora
* support chatglm
* Update train.py
* assorted fixes, plus [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -11,6 +11,7 @@ from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
 from ..modeling.chatglm2 import (
     get_chatglm_sequence_parallel_attention_forward,
     get_chatglm_sequence_parallel_forward_fn,
+    get_flash_attention_forward_for_chat_glm_model,
     get_flash_core_attention_forward,
     get_jit_fused_glm_block_forward,
 )
@@ -203,6 +204,13 @@ class ChatGLMPolicy(Policy):
                 policy=policy,
                 target_key="CoreAttention",
             )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_flash_attention_forward_for_chat_glm_model(),
+                },
+                policy=policy,
+                target_key="ChatGLMModel",
+            )
 
         # use sequence parallel
         if self.shard_config.enable_sequence_parallelism: