[shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
Author: Hongxin Liu
Date: 2024-03-27 11:19:32 +08:00
Committed by: GitHub
Parent: 9a3321e9f4
Commit: 19e1a5cf16
45 changed files with 2543 additions and 1170 deletions
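Only the OPT policy file is excerpted below; the ColoAttention changes that actually implement custom-mask support live in other files of this commit. As a rough, hypothetical sketch of the underlying idea — building a combined causal + padding mask and handing it to an SDPA-style kernel as a custom additive mask — the following uses PyTorch's `scaled_dot_product_attention` directly (illustrative names only, not the ColoAttention API):

```python
# Hypothetical sketch (not the ColoAttention API): combine a causal mask
# with an optional padding mask and pass the result to PyTorch SDPA.
import torch
import torch.nn.functional as F

def sdpa_with_custom_mask(q, k, v, padding_mask=None):
    # q, k, v: (batch, heads, seq_len, head_dim)
    # padding_mask: (batch, seq_len) bool, True = real token
    seq_len = q.size(2)
    neg_inf = float("-inf")
    # Additive causal mask: 0 where attention is allowed, -inf above the diagonal.
    future = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1
    )
    mask = torch.zeros(seq_len, seq_len, dtype=q.dtype, device=q.device)
    mask = mask.masked_fill(future, neg_inf)[None, None]  # (1, 1, S, S)
    if padding_mask is not None:
        # Mask out padded key positions; assumes every query still has at
        # least one attendable key, otherwise its softmax row becomes NaN.
        pad = torch.zeros_like(padding_mask, dtype=q.dtype)
        pad = pad.masked_fill(~padding_mask, neg_inf)
        mask = mask + pad[:, None, None, :]  # broadcasts to (B, 1, S, S)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```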

colossalai/shardformer/policies/opt.py

@@ -9,7 +9,12 @@ from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col
 from .._utils import getattr_
 from ..modeling.jit import get_jit_fused_dropout_add_func
-from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
+from ..modeling.opt import (
+    OPTPipelineForwards,
+    get_jit_fused_opt_decoder_layer_forward,
+    get_opt_decoder_forward_for_flash_attention,
+    get_opt_flash_attention_forward,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [
@@ -27,6 +32,7 @@ class OPTPolicy(Policy):
         import transformers
         from packaging.version import Version

+        # TODO: remove this version check when transformers>=4.36.0
         assert Version(transformers.__version__) <= Version(
             "4.33.0"
         ), "The OPT model should run on a transformers version not greater than 4.33.0."
@@ -111,7 +117,9 @@ class OPTPolicy(Policy):
             # optimization configuration
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
-                    suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+                    suffix="final_layer_norm",
+                    target_module=norm_cls,
+                    ignore_if_not_exist=True,
                 ),
                 policy=policy,
                 target_key=OPTDecoder,
@@ -119,10 +127,14 @@
             self.append_or_create_submodule_replacement(
                 description=[
                     SubModuleReplacementDescription(
-                        suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+                        suffix="self_attn_layer_norm",
+                        target_module=norm_cls,
+                        ignore_if_not_exist=True,
                     ),
                     SubModuleReplacementDescription(
-                        suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+                        suffix="final_layer_norm",
+                        target_module=norm_cls,
+                        ignore_if_not_exist=True,
                     ),
                 ],
                 policy=policy,
@@ -133,11 +145,19 @@
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_opt_flash_attention_forward(),
+                    "forward": get_opt_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
                 target_key=OPTAttention,
             )
+            if not self.shard_config.pipeline_stage_manager:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_opt_decoder_forward_for_flash_attention(self.shard_config),
+                    },
+                    policy=policy,
+                    target_key=OPTDecoder,
+                )

         # use jit fused operator
         if self.shard_config.enable_jit_fused:
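Both helpers in the hunk above follow the same factory pattern: `get_opt_flash_attention_forward(self.shard_config)` and `get_opt_decoder_forward_for_flash_attention(self.shard_config)` capture the shard config in a closure and return a plain function that the policy binds onto the target module as its new `forward`. A minimal sketch of that pattern, using hypothetical toy classes rather than the real OPT modules:

```python
# Minimal sketch of the forward-factory pattern (ToyConfig/ToyBlock are
# hypothetical stand-ins, not the real OPT classes).
import torch
import torch.nn as nn

class ToyConfig:
    enable_flash_attention = True

def get_toy_forward(shard_config):
    # Capture shard_config; the returned function's first parameter is
    # the module instance, mirroring a bound forward method.
    def forward(self, x):
        if shard_config.enable_flash_attention:
            return torch.relu(self.proj(x))  # pretend this is the fused path
        return self.proj(x)
    return forward

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

block = ToyBlock()
# Bind the replacement as an instance method, as the policy machinery does.
block.forward = get_toy_forward(ToyConfig()).__get__(block)
out = block(torch.randn(2, 8))
```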
@@ -190,7 +210,14 @@
         layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
         stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-        method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+        method_replacement = {
+            "forward": partial(
+                new_forward,
+                stage_manager=stage_manager,
+                stage_index=stage_index,
+                shard_config=self.shard_config,
+            )
+        }
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls
         )
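The pipeline path reaches the same end by a different route: `functools.partial` pre-binds `stage_manager`, `stage_index`, and (new in this commit) `shard_config` onto the pipeline forward, so the bound callable exposes an ordinary forward signature to callers. A toy illustration with hypothetical names:

```python
# Toy illustration of pre-binding pipeline context with functools.partial
# (hypothetical names, not the real pipeline forward signature).
from functools import partial

def staged_forward(x, *, stage_manager, stage_index, shard_config):
    start, end = stage_index
    return f"{stage_manager}: run layers [{start}, {end}) on {x!r} with {shard_config}"

bound = partial(
    staged_forward,
    stage_manager="stage 1/4",
    stage_index=(6, 12),
    shard_config="shard_cfg",
)
print(bound("hidden_states"))  # callers only pass the regular inputs
```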
@@ -203,7 +230,9 @@ class OPTModelPolicy(OPTPolicy):
         policy = super().module_policy()

         if self.pipeline_stage_manager:
             self.set_pipeline_forward(
-                model_cls=OPTModel, new_forward=OPTPipelineForwards.opt_model_forward, policy=policy
+                model_cls=OPTModel,
+                new_forward=OPTPipelineForwards.opt_model_forward,
+                policy=policy,
             )
         return policy
@@ -223,14 +252,18 @@ class OPTForCausalLMPolicy(OPTPolicy):
         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
-                    suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                    suffix="lm_head",
+                    target_module=Linear1D_Col,
+                    kwargs=dict(gather_output=True),
                 ),
                 policy=policy,
                 target_key=OPTForCausalLM,
             )

         if self.pipeline_stage_manager:
             self.set_pipeline_forward(
-                model_cls=OPTForCausalLM, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, policy=policy
+                model_cls=OPTForCausalLM,
+                new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,
+                policy=policy,
             )
         return policy
@@ -246,7 +279,12 @@
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             num_stages = self.pipeline_stage_manager.num_stages
             if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
-                return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
+                return [
+                    {
+                        0: opt_model.model.decoder.embed_tokens.weight,
+                        num_stages - 1: opt_model.lm_head.weight,
+                    }
+                ]
         return []

     def postprocess(self):
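For context, the method reformatted in the last hunk reports parameters that must stay synchronized across pipeline stages: with tied embeddings, stage 0 holds `embed_tokens.weight` while the last stage holds `lm_head.weight`. The `id()` comparison is sound because PyTorch weight tying aliases a single tensor object, as this standalone check shows:

```python
# Standalone check: tied weights are literally the same tensor object,
# so identity comparison is a reliable tying test.
import torch.nn as nn

embed = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = embed.weight  # standard weight tying

assert id(embed.weight) == id(lm_head.weight)
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()
```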