[shardformer] refactor pipeline grad ckpt config (#5646)

* [shardformer] refactor pipeline grad ckpt config

* [shardformer] refactor pipeline grad ckpt config

* [pipeline] fix stage manager
This commit is contained in:
Hongxin Liu
2024-04-25 15:19:30 +08:00
committed by GitHub
parent 7ef91606e1
commit 1b387ca9fe
11 changed files with 59 additions and 102 deletions

View File

@@ -28,6 +28,7 @@ class SubModuleReplacementDescription:
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
target_module: Union[ParallelModule, BaseLayerNorm]
kwargs: Dict[str, Any] = None
@@ -54,6 +55,7 @@ class ModulePolicyDescription:
object which specifies the module to be replaced and the target module used to replacement.
method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
"""
attribute_replacement: Dict[str, Any] = None
param_replacement: List[Callable] = None
sub_module_replacement: List[SubModuleReplacementDescription] = None