mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 20:46:00 +00:00
[shardformer] refactored some doc and api (#4137)
* [shardformer] refactored some doc and api * polish code
This commit is contained in:
@@ -44,36 +44,30 @@ class T5BasePolicy(Policy):
|
||||
|
||||
base_policy = {
|
||||
T5Stack:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=Embedding1D,
|
||||
)
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=Embedding1D,
|
||||
)
|
||||
]),
|
||||
T5LayerSelfAttention:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
T5LayerCrossAttention:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
T5Attention:
|
||||
ModulePolicyDescription(attribute_replacement={
|
||||
"d_model":
|
||||
@@ -83,7 +77,6 @@ class T5BasePolicy(Policy):
|
||||
"inner_dim":
|
||||
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="q",
|
||||
@@ -107,51 +100,44 @@ class T5BasePolicy(Policy):
|
||||
ignore_if_not_exist=True)
|
||||
]),
|
||||
T5LayerFF:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]),
|
||||
T5DenseGatedActDense:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi_0",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi_1",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(suffix="wo",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True)),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi_0",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi_1",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]),
|
||||
T5DenseActDense:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wo",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wo",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
@@ -167,9 +153,6 @@ class T5BasePolicy(Policy):
|
||||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
|
||||
|
||||
@@ -185,14 +168,12 @@ class T5ModelPolicy(T5BasePolicy):
|
||||
from transformers import T5Model
|
||||
|
||||
base_policy = super().module_policy()
|
||||
base_policy[T5Model] = ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
return base_policy
|
||||
|
||||
|
||||
@@ -202,18 +183,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
policy = super().module_policy()
|
||||
policy[T5ForConditionalGeneration] = ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
@@ -235,14 +212,12 @@ class T5EncoderPolicy(T5BasePolicy):
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
base_policy = super().module_policy()
|
||||
base_policy[T5EncoderModel] = ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
return base_policy
|
||||
|
||||
def postprocess(self):
|
||||
|
Reference in New Issue
Block a user