mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 04:55:25 +00:00
[shardformer] supported fused normalization (#4112)
This commit is contained in:
@@ -16,6 +16,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
|
||||
|
||||
class BertPolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
@@ -99,7 +102,8 @@ class BertPolicy(Policy):
|
||||
])
|
||||
}
|
||||
|
||||
if self.shard_config.fused_layernorm:
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
base_policy[BertLayer].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.LayerNorm",
|
||||
@@ -150,12 +154,16 @@ class BertForPretrainingPolicy(BertPolicy):
|
||||
kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
if self.shard_config.fused_layernorm:
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="transform.LayerNorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
))
|
||||
|
||||
# append extra policy
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
||||
@@ -187,7 +195,7 @@ class BertLMHeadModelPolicy(BertPolicy):
|
||||
kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
if self.shard_config.fused_layernorm:
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="transform.LayerNorm",
|
||||
@@ -224,12 +232,15 @@ class BertForMaskedLMPolicy(BertPolicy):
|
||||
kwargs={"gather_output": True}),
|
||||
])
|
||||
}
|
||||
if self.shard_config.fused_layernorm:
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
addon_module[BertLMPredictionHead].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(
|
||||
suffix="transform.LayerNorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
))
|
||||
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
|
||||
@@ -316,4 +327,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
return module_policy
|
||||
|
Reference in New Issue
Block a user