[shardformer] refactored layernorm (#4086)

This commit is contained in:
Frank Lee
2023-06-26 18:05:00 +08:00
parent c4b1b65931
commit d33a44e8c3
4 changed files with 51 additions and 77 deletions

View File

@@ -103,17 +103,17 @@ class BertPolicy(Policy):
base_policy[BertLayer].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
-target_module=col_nn.LayerNorm1D,
+target_module=col_nn.FusedLayerNorm,
))
base_policy[BertLayer].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="output.LayerNorm",
-target_module=col_nn.LayerNorm1D,
+target_module=col_nn.FusedLayerNorm,
))
base_policy[BertEmbeddings].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="LayerNorm",
-target_module=col_nn.LayerNorm1D,
+target_module=col_nn.FusedLayerNorm,
),)
return base_policy
@@ -154,7 +154,7 @@ class BertForPretrainingPolicy(BertPolicy):
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
-target_module=col_nn.LayerNorm1D,
+target_module=col_nn.FusedLayerNorm,
))
module_policy.update(addon_module)
return module_policy
@@ -191,7 +191,7 @@ class BertLMHeadModelPolicy(BertPolicy):
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
-target_module=col_nn.LayerNorm1D,
+target_module=col_nn.FusedLayerNorm,
))
module_policy.update(addon_module)
return module_policy
@@ -228,7 +228,7 @@ class BertForMaskedLMPolicy(BertPolicy):
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
-target_module=col_nn.LayerNorm1D,
+target_module=col_nn.FusedLayerNorm,
))
module_policy.update(addon_module)
return module_policy