[shardformer] refactored some doc and api (#4137)

* [shardformer] refactored some doc and api

* polish code
This commit is contained in:
Frank Lee
2023-07-03 15:29:11 +08:00
parent 7f9b30335b
commit 74257cb446
15 changed files with 355 additions and 490 deletions

View File

@@ -48,7 +48,6 @@ class BertPolicy(Policy):
"crossattention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
@@ -88,18 +87,16 @@ class BertPolicy(Policy):
)
]),
BertEmbeddings:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
)
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
)
])
}
# optimization configuration
@@ -121,10 +118,6 @@ class BertPolicy(Policy):
),)
return base_policy
def new_model_class(self):
# do nothing
return None
def postprocess(self):
return self.model
@@ -148,13 +141,10 @@ class BertForPretrainingPolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True}),
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
])
}
# optimization configuration
@@ -191,13 +181,10 @@ class BertLMHeadModelPolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True}),
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
])
}
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
@@ -230,13 +217,10 @@ class BertForMaskedLMPolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True}),
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
])
}
# optimization configuration
@@ -272,14 +256,12 @@ class BertForSequenceClassificationPolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertForSequenceClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy
@@ -297,14 +279,12 @@ class BertForTokenClassificationPolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertForTokenClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy
@@ -329,14 +309,12 @@ class BertForMultipleChoicePolicy(BertPolicy):
module_policy = super().module_policy()
addon_module = {
BertForMultipleChoice:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
])
}
module_policy.update(addon_module)
return module_policy