[shardformer] fix submodule replacement bug when enabling pp (#4544)

Author:    Baizhou Zhang
Date:      2023-08-31 09:57:18 +08:00
Committer: GitHub
Parent:    ec18fc7340
Commit:    2c787d7f47
5 changed files with 21 additions and 12 deletions
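
The fix, in short: when pipeline parallelism (pp) is enabled and the current stage only holds some layers, _recursive_replace_layer used a single per-parent flag to gate both parameter and submodule replacement, so a parent's membership in include decided the fate of all of its children at once. The patch keeps the gate for parameter replacement but moves the submodule check into _replace_sub_module, where each submodule is tested individually. A minimal sketch of the difference (the Block module and the contents of include are illustrative, not project code):

import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(4, 4)
        self.mlp = nn.Linear(4, 4)

block = Block()
# Suppose the current pipeline stage holds the block and its attn, but not mlp.
include = {block, block.attn}

# Old logic: one flag on the parent module decided everything underneath it.
can_replace_param_or_layer = include is None or block in include
assert can_replace_param_or_layer  # True, so mlp would also have been replaced

# New logic: each submodule is checked on its own before replacement.
for name, child in block.named_children():
    if include is not None and child not in include:
        continue  # mlp is skipped on this stage
    print(f"replace {name}")  # only 'attn' is printed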

@@ -92,22 +92,21 @@ class ModelSharder(object):
             param_replacement (List[Callable]): The function list to get parameter shard information in policy
             method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
             sub_module_replacement (List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
+            include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
         """
-        # released layers are not shardable
-        can_replace_param_or_layer = include is None or module in include
         if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
                 (module.__class__ == origin_cls):
             if attr_replacement is not None:
                 self._replace_attr(module, attr_replacement)
-            if param_replacement is not None and can_replace_param_or_layer:
+            if param_replacement is not None and (include is None or module in include):
                 self._replace_param(module, param_replacement)
             if method_replacement is not None:
                 self._replace_method(module, method_replacement)
-            if sub_module_replacement is not None and can_replace_param_or_layer:
-                self._replace_sub_module(module, sub_module_replacement)
+            if sub_module_replacement is not None:
+                self._replace_sub_module(module, sub_module_replacement, include)
         for name, child in module.named_children():
             self._recursive_replace_layer(child,
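
For this per-submodule test to work at any depth, the include set has to contain every module the current stage holds, including nested children of the held layers. A hedged sketch of how such a set could be assembled (the helper name is hypothetical; only nn.Module.modules() is assumed):

from typing import Iterable, Set
import torch.nn as nn

def collect_held_modules(held_layers: Iterable[nn.Module]) -> Set[nn.Module]:
    # nn.Module.modules() yields the module itself plus all of its
    # descendants, so membership tests succeed for nested submodules too.
    held: Set[nn.Module] = set()
    for layer in held_layers:
        held.update(layer.modules())
    return held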
@@ -154,18 +153,17 @@ class ModelSharder(object):
         bound_method = MethodType(new_method, module)
         setattr(module, method_name, bound_method)
 
-    def _replace_sub_module(
-        self,
-        org_layer: nn.Module,
-        sub_module_replacement: List[SubModuleReplacementDescription],
-    ) -> None:
+    def _replace_sub_module(self,
+                            org_layer: nn.Module,
+                            sub_module_replacement: List[SubModuleReplacementDescription],
+                            include: Optional[Set[nn.Module]] = None) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
 
         Args:
             org_layer (torch.nn.Module): The origin layer object to shard
             sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list
+            include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
         """
         for description in sub_module_replacement:
             suffix = description.suffix
@@ -174,9 +172,12 @@ class ModelSharder(object):
             assert target_module is not None, 'target_module should not be None'
             # TODO: support different parallel mode
             native_sub_module = getattr_(org_layer, suffix, ignore=True)
+            # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
+            if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
+                continue
             assert not isinstance(native_sub_module, target_module), \
                 f"The module with suffix {suffix} has been replaced, please check the policy"