[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -12,7 +12,7 @@ from ..policies.base_policy import Policy, SubModuleReplacementDescription
from .shard_config import ShardConfig
from .utils import set_tensors_to_none
__all__ = ['ModelSharder', 'shard_model']
__all__ = ["ModelSharder", "shard_model"]
class ModelSharder(object):
@@ -64,13 +64,15 @@ class ModelSharder(object):
param_replacement = module_description.param_replacement
sub_module_replacement = module_description.sub_module_replacement
method_replacement = module_description.method_replacement
self._recursive_replace_layer(self.model,
layer_cls,
attr_replacement,
param_replacement,
method_replacement,
sub_module_replacement,
include=include)
self._recursive_replace_layer(
self.model,
layer_cls,
attr_replacement,
param_replacement,
method_replacement,
sub_module_replacement,
include=include,
)
def _recursive_replace_layer(
self,
@@ -94,8 +96,9 @@ class ModelSharder(object):
sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
"""
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
(module.__class__ == origin_cls):
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or (
module.__class__ == origin_cls
):
if attr_replacement is not None:
self._replace_attr(module, attr_replacement)
@@ -109,13 +112,15 @@ class ModelSharder(object):
self._replace_sub_module(module, sub_module_replacement, include)
for name, child in module.named_children():
self._recursive_replace_layer(child,
origin_cls,
attr_replacement,
param_replacement,
method_replacement,
sub_module_replacement,
include=include)
self._recursive_replace_layer(
child,
origin_cls,
attr_replacement,
param_replacement,
method_replacement,
sub_module_replacement,
include=include,
)
def _replace_attr(
self,
@@ -153,10 +158,12 @@ class ModelSharder(object):
bound_method = MethodType(new_method, module)
setattr(module, method_name, bound_method)
def _replace_sub_module(self,
org_layer: nn.Module,
sub_module_replacement: List[SubModuleReplacementDescription],
include: Optional[Set[nn.Module]] = None) -> None:
def _replace_sub_module(
self,
org_layer: nn.Module,
sub_module_replacement: List[SubModuleReplacementDescription],
include: Optional[Set[nn.Module]] = None,
) -> None:
r"""
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
@@ -170,7 +177,7 @@ class ModelSharder(object):
target_module = description.target_module
kwargs = {} if description.kwargs is None else description.kwargs
assert target_module is not None, 'target_module should not be None'
assert target_module is not None, "target_module should not be None"
native_sub_module = getattr_(org_layer, suffix, ignore=True)
@@ -178,8 +185,9 @@ class ModelSharder(object):
if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
continue
assert not isinstance(native_sub_module, target_module), \
f"The module with suffix {suffix} has been replaced, please check the policy"
assert not isinstance(
native_sub_module, target_module
), f"The module with suffix {suffix} has been replaced, please check the policy"
# if it is None and we are allowed to ignore this module
# just skip
@@ -187,9 +195,9 @@ class ModelSharder(object):
continue
try:
replace_layer = target_module.from_native_module(native_sub_module,
self.shard_config.tensor_parallel_process_group,
**kwargs)
replace_layer = target_module.from_native_module(
native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs
)
except Exception as e:
raise RuntimeError(
f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
@@ -200,7 +208,6 @@ class ModelSharder(object):
setattr_(org_layer, suffix, replace_layer)
def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]:
def collect_sub_modules(module: nn.Module):
if module is None:
return