[shardformer]delete xformers (#5859)

* delete xformers

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: flybird11111
Date: 2024-06-28 11:20:04 +08:00
Committed by: GitHub
Parent: eaea88cf9e
Commit: 773d9f964a
7 changed files with 7 additions and 412 deletions


@@ -11,7 +11,6 @@ import colossalai.shardformer.layer as col_nn
 from ..modeling.bert import (
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
-    get_bert_flash_attention_forward,
     get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
@@ -49,7 +48,6 @@ class BertPolicy(Policy):
             BertLayer,
             BertModel,
             BertOutput,
-            BertSelfAttention,
             BertSelfOutput,
         )
@@ -218,16 +216,6 @@ class BertPolicy(Policy):
                 target_key=BertEmbeddings,
             )
-        # use flash attention
-        if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_bert_flash_attention_forward(),
-                },
-                policy=policy,
-                target_key=BertSelfAttention,
-            )
         # use jit operator
         if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(
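
For context, the block removed above used ShardFormer's policy pattern of conditionally swapping a module's `forward` method based on `shard_config` flags; the `enable_jit_fused` branch that remains follows the same mechanism. The sketch below is a minimal, standalone illustration of that pattern, not ColossalAI's actual implementation: it does not import ColossalAI, and the `ShardConfig`, `ModulePolicyDescription`, `SimplePolicy`, and `fused_forward` names here are hypothetical stand-ins, while `append_or_create_method_replacement`, `policy`, and `target_key` mirror names visible in the diff.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Type


@dataclass
class ShardConfig:
    # Hypothetical stand-in for ShardFormer's shard_config flags.
    enable_jit_fused: bool = False


@dataclass
class ModulePolicyDescription:
    # Maps method names (e.g. "forward") to their replacement callables.
    method_replacement: Dict[str, Callable] = field(default_factory=dict)


class SimplePolicy:
    """Minimal sketch of a policy that registers per-module method replacements."""

    def __init__(self, shard_config: ShardConfig):
        self.shard_config = shard_config

    def append_or_create_method_replacement(
        self,
        description: Dict[str, Callable],
        policy: Dict[Type, ModulePolicyDescription],
        target_key: Type,
    ) -> None:
        # Create an entry for the target module class if needed, then record
        # which of its methods should be overridden when the model is sharded.
        entry = policy.setdefault(target_key, ModulePolicyDescription())
        entry.method_replacement.update(description)

    def module_policy(self, target_cls: Type) -> Dict[Type, ModulePolicyDescription]:
        policy: Dict[Type, ModulePolicyDescription] = {}
        # Only register the replacement when the corresponding flag is set,
        # mirroring the conditional branches visible in the diff above.
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(
                description={"forward": fused_forward},  # hypothetical replacement
                policy=policy,
                target_key=target_cls,
            )
        return policy


def fused_forward(self, x):
    # Placeholder for an optimized forward implementation.
    return x


class DummySelfOutput:
    def forward(self, x):
        return x


if __name__ == "__main__":
    policy = SimplePolicy(ShardConfig(enable_jit_fused=True)).module_policy(DummySelfOutput)
    print(policy[DummySelfOutput].method_replacement)
```

In the real library, the policy dictionary built this way is consumed at shard time to patch the listed methods onto the target modules; this commit simply drops the xformers/flash-attention branch from that registration step while leaving the mechanism itself unchanged.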