[shardformer]delete xformers (#5859)

* delete xformers

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: flybird11111
Date: 2024-06-28 11:20:04 +08:00
Committed by: GitHub
Parent: eaea88cf9e
Commit: 773d9f964a
7 changed files with 7 additions and 412 deletions

@@ -11,14 +11,13 @@ import colossalai.shardformer.layer as col_nn
 from ..modeling.bloom import (
     BloomPipelineForwards,
     build_bloom_alibi_tensor_fn,
-    get_bloom_flash_attention_forward,
     get_bloom_sequence_parallel_forward_fn,
     get_jit_fused_bloom_attention_forward,
     get_jit_fused_bloom_gelu_forward,
     get_jit_fused_bloom_mlp_forward,
     get_lm_forward_with_dist_cross_entropy,
 )
-from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
+from ..modeling.jit import get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -165,16 +164,6 @@ class BloomPolicy(Policy):
             target_key=BloomModel,
         )
-        if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_bloom_flash_attention_forward(),
-                    "dropout_add": get_dropout_add_func(),
-                },
-                policy=policy,
-                target_key=BloomAttention,
-            )
-
         # enable jit fused operator
         if self.shard_config.enable_jit_fused:
             self.append_or_create_method_replacement(
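
For readers unfamiliar with the policy mechanism this hunk edits, below is a self-contained toy sketch of the method-replacement pattern that both branches use. The names ToyAttention, jit_fused_forward, and the simplified append_or_create_method_replacement are illustrative stand-ins, not ColossalAI's actual classes or API; only the overall shape (a policy dict mapping a target module class to replacement callables) mirrors the code in the diff.

# Toy sketch of the shardformer "method replacement" pattern shown above.
# All names here are hypothetical stand-ins, not ColossalAI API.

class ToyAttention:
    def forward(self, x):
        return x + 1  # original forward

def jit_fused_forward(self, x):
    # stand-in for a fused forward such as get_jit_fused_bloom_attention_forward()
    return x + 2

def append_or_create_method_replacement(description, policy, target_key):
    # merge the replacement methods into the policy entry for target_key
    policy.setdefault(target_key, {}).update(description)

policy = {}
enable_jit_fused = True  # plays the role of self.shard_config.enable_jit_fused

if enable_jit_fused:
    append_or_create_method_replacement(
        description={"forward": jit_fused_forward},
        policy=policy,
        target_key=ToyAttention,
    )

# applying the policy patches each target class with its replacement methods
for target_cls, methods in policy.items():
    for name, fn in methods.items():
        setattr(target_cls, name, fn)

print(ToyAttention().forward(0))  # prints 2: the fused forward is now in effect

After this commit, only the JIT-fused branch performs such a replacement for BloomAttention; the xformers-backed flash-attention branch is removed entirely.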