mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 13:05:26 +00:00
[shardformer] refactored some doc and api (#4137)
* [shardformer] refactored some doc and api * polish code
This commit is contained in:
@@ -48,7 +48,6 @@ class BertPolicy(Policy):
|
||||
"crossattention.self.num_attention_heads":
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.self.query",
|
||||
@@ -88,18 +87,16 @@ class BertPolicy(Policy):
|
||||
)
|
||||
]),
|
||||
BertEmbeddings:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
@@ -121,10 +118,6 @@ class BertPolicy(Policy):
|
||||
),)
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
# do nothing
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
@@ -148,13 +141,10 @@ class BertForPretrainingPolicy(BertPolicy):
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="decoder",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True}),
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
@@ -191,13 +181,10 @@ class BertLMHeadModelPolicy(BertPolicy):
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="decoder",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True}),
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
@@ -230,13 +217,10 @@ class BertForMaskedLMPolicy(BertPolicy):
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="decoder",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True}),
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
|
||||
# optimization configuration
|
||||
@@ -272,14 +256,12 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForSequenceClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
@@ -297,14 +279,12 @@ class BertForTokenClassificationPolicy(BertPolicy):
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForTokenClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
@@ -329,14 +309,12 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForMultipleChoice:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
Reference in New Issue
Block a user