[shardformer] supported fused normalization (#4112)

This commit is contained in:
Frank Lee
2023-06-30 09:32:37 +08:00
parent b1c2901530
commit f3b6aaa6b7
12 changed files with 207 additions and 31 deletions

View File

@@ -16,6 +16,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
class BertPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
# reshape the embedding layer
r"""
@@ -99,7 +102,8 @@ class BertPolicy(Policy):
])
}
if self.shard_config.fused_layernorm:
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[BertLayer].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
@@ -150,12 +154,16 @@ class BertForPretrainingPolicy(BertPolicy):
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
# optimization configuration
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
))
# append extra policy
module_policy.update(addon_module)
return module_policy
@@ -187,7 +195,7 @@ class BertLMHeadModelPolicy(BertPolicy):
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
@@ -224,12 +232,15 @@ class BertForMaskedLMPolicy(BertPolicy):
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
# optimization configuration
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
))
module_policy.update(addon_module)
return module_policy
@@ -316,4 +327,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
])
}
module_policy.update(addon_module)
return module_policy
return module_policy