mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
[shardformer] refactored layernorm (#4086)
This commit is contained in:
@@ -103,17 +103,17 @@ class BertPolicy(Policy):
|
||||
base_policy[BertLayer].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.LayerNorm",
|
||||
- target_module=col_nn.LayerNorm1D,
+ target_module=col_nn.FusedLayerNorm,
|
||||
))
|
||||
base_policy[BertLayer].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.LayerNorm",
|
||||
- target_module=col_nn.LayerNorm1D,
+ target_module=col_nn.FusedLayerNorm,
|
||||
))
|
||||
base_policy[BertEmbeddings].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="LayerNorm",
|
||||
- target_module=col_nn.LayerNorm1D,
+ target_module=col_nn.FusedLayerNorm,
|
||||
),)
|
||||
return base_policy
|
||||
|
||||
@@ -154,7 +154,7 @@ class BertForPretrainingPolicy(BertPolicy):
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="transform.LayerNorm",
|
||||
- target_module=col_nn.LayerNorm1D,
+ target_module=col_nn.FusedLayerNorm,
|
||||
))
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
@@ -191,7 +191,7 @@ class BertLMHeadModelPolicy(BertPolicy):
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="transform.LayerNorm",
|
||||
- target_module=col_nn.LayerNorm1D,
+ target_module=col_nn.FusedLayerNorm,
|
||||
))
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
@@ -228,7 +228,7 @@ class BertForMaskedLMPolicy(BertPolicy):
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="transform.LayerNorm",
|
||||
- target_module=col_nn.LayerNorm1D,
+ target_module=col_nn.FusedLayerNorm,
|
||||
))
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
Reference in New Issue
Block a user