Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -12,7 +12,7 @@ from ..policies.base_policy import Policy, SubModuleReplacementDescription
 from .shard_config import ShardConfig
 from .utils import set_tensors_to_none
 
-__all__ = ['ModelSharder', 'shard_model']
+__all__ = ["ModelSharder", "shard_model"]
 
 
 class ModelSharder(object):
@@ -64,13 +64,15 @@ class ModelSharder(object):
         param_replacement = module_description.param_replacement
         sub_module_replacement = module_description.sub_module_replacement
         method_replacement = module_description.method_replacement
-        self._recursive_replace_layer(self.model,
-                                      layer_cls,
-                                      attr_replacement,
-                                      param_replacement,
-                                      method_replacement,
-                                      sub_module_replacement,
-                                      include=include)
+        self._recursive_replace_layer(
+            self.model,
+            layer_cls,
+            attr_replacement,
+            param_replacement,
+            method_replacement,
+            sub_module_replacement,
+            include=include,
+        )
 
     def _recursive_replace_layer(
         self,
@@ -94,8 +96,9 @@ class ModelSharder(object):
             sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
             include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
         """
-        if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
-           (module.__class__ == origin_cls):
+        if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or (
+            module.__class__ == origin_cls
+        ):
             if attr_replacement is not None:
                 self._replace_attr(module, attr_replacement)
 
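The attr_replacement branch at the end of this hunk overwrites plain attributes on a matched module. A hedged sketch of what such a mapping can express — ToyAttention and the num_heads rescaling are hypothetical, and the real logic lives in ModelSharder._replace_attr, which this diff does not show:

import torch.nn as nn


class ToyAttention(nn.Module):
    # hypothetical stand-in for a transformer attention block
    def __init__(self):
        super().__init__()
        self.num_heads = 16


def replace_attrs(module: nn.Module, attr_replacement: dict) -> None:
    # overwrite each named attribute only if the module actually has it
    for attr_name, new_value in attr_replacement.items():
        if hasattr(module, attr_name):
            setattr(module, attr_name, new_value)


attn = ToyAttention()
replace_attrs(attn, {"num_heads": 16 // 4})  # e.g. 4-way tensor parallelism
print(attn.num_heads)  # 4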
@@ -109,13 +112,15 @@ class ModelSharder(object):
                 self._replace_sub_module(module, sub_module_replacement, include)
 
         for name, child in module.named_children():
-            self._recursive_replace_layer(child,
-                                          origin_cls,
-                                          attr_replacement,
-                                          param_replacement,
-                                          method_replacement,
-                                          sub_module_replacement,
-                                          include=include)
+            self._recursive_replace_layer(
+                child,
+                origin_cls,
+                attr_replacement,
+                param_replacement,
+                method_replacement,
+                sub_module_replacement,
+                include=include,
+            )
 
     def _replace_attr(
         self,
@@ -153,10 +158,12 @@ class ModelSharder(object):
             bound_method = MethodType(new_method, module)
             setattr(module, method_name, bound_method)
 
-    def _replace_sub_module(self,
-                            org_layer: nn.Module,
-                            sub_module_replacement: List[SubModuleReplacementDescription],
-                            include: Optional[Set[nn.Module]] = None) -> None:
+    def _replace_sub_module(
+        self,
+        org_layer: nn.Module,
+        sub_module_replacement: List[SubModuleReplacementDescription],
+        include: Optional[Set[nn.Module]] = None,
+    ) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
 
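The unchanged context at the top of this hunk (MethodType plus setattr) is how method_replacement works: types.MethodType binds a plain function to a module instance, so one model's method can be patched without subclassing or touching other instances. A toy demonstration of the same mechanism — doubled_forward is illustrative, not ColossalAI code:

from types import MethodType

import torch
import torch.nn as nn


def doubled_forward(self, x):
    # hypothetical replacement: run the original forward, then double it
    return 2 * nn.Linear.forward(self, x)


layer = nn.Linear(3, 3)
# same pattern as the diff: bind the function to the instance, then assign it
# (the sharder does this with setattr(module, method_name, bound_method))
layer.forward = MethodType(doubled_forward, layer)

x = torch.ones(1, 3)
print(layer(x))  # calls now route through doubled_forward for this instance only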
@@ -170,7 +177,7 @@ class ModelSharder(object):
             target_module = description.target_module
             kwargs = {} if description.kwargs is None else description.kwargs
 
-            assert target_module is not None, 'target_module should not be None'
+            assert target_module is not None, "target_module should not be None"
 
             native_sub_module = getattr_(org_layer, suffix, ignore=True)
 
@@ -178,8 +185,9 @@ class ModelSharder(object):
             if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
                 continue
 
-            assert not isinstance(native_sub_module, target_module), \
-                f"The module with suffix {suffix} has been replaced, please check the policy"
+            assert not isinstance(
+                native_sub_module, target_module
+            ), f"The module with suffix {suffix} has been replaced, please check the policy"
 
             # if it is None and we are allowed to ignore this module
             # just skip
@@ -187,9 +195,9 @@ class ModelSharder(object):
                 continue
 
             try:
-                replace_layer = target_module.from_native_module(native_sub_module,
-                                                                 self.shard_config.tensor_parallel_process_group,
-                                                                 **kwargs)
+                replace_layer = target_module.from_native_module(
+                    native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs
+                )
             except Exception as e:
                 raise RuntimeError(
                     f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
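The try-block above relies on a contract rather than a concrete type: every target_module must expose a from_native_module constructor that consumes the native layer and, here, the tensor-parallel process group. A minimal sketch of that contract, assuming a hypothetical WrappedLinear; a real ColossalAI layer would also shard the parameters across process_group:

import torch
import torch.nn as nn


class WrappedLinear(nn.Module):
    def __init__(self, weight: nn.Parameter, bias: nn.Parameter):
        super().__init__()
        self.weight = weight
        self.bias = bias

    def forward(self, x):
        return nn.functional.linear(x, self.weight, self.bias)

    @classmethod
    def from_native_module(cls, module: nn.Linear, process_group=None, **kwargs) -> "WrappedLinear":
        # reuse the native parameters instead of re-allocating them; a real
        # tensor-parallel layer would split them over process_group here
        return cls(module.weight, module.bias)


native = nn.Linear(4, 4)
replaced = WrappedLinear.from_native_module(native, process_group=None)
print(torch.allclose(replaced(torch.ones(1, 4)), native(torch.ones(1, 4))))  # True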
@@ -200,7 +208,6 @@ class ModelSharder(object):
             setattr_(org_layer, suffix, replace_layer)
 
     def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]:
-
         def collect_sub_modules(module: nn.Module):
             if module is None:
                 return
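The last hunk only deletes a blank line inside _get_recursive_held_layers, but the surrounding context shows its job: flatten each held layer into the full set of sub-modules it owns, so a pipeline stage keeps every module it needs on the current device. A standalone sketch of that collection pattern, hedged as an approximation of the method the diff truncates:

from typing import List, Optional

import torch.nn as nn


def get_recursive_held_layers(held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]:
    if held_layers is None:
        return None
    recursive_held_layers: List[nn.Module] = []

    def collect_sub_modules(module: nn.Module) -> None:
        if module is None:
            return
        # record the module itself, then descend into its children
        recursive_held_layers.append(module)
        for child in module.children():
            collect_sub_modules(child)

    for layer in held_layers:
        collect_sub_modules(layer)
    return recursive_held_layers


stage = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
print(len(get_recursive_held_layers([stage])))  # Sequential + Linear + ReLU = 3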