Repository: https://github.com/hpcaitech/ColossalAI.git
[context] use meta tensor to init model lazily (#1187)

* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
  This reverts commit df7e6506d4.
* [context] use meta tensor to init model lazily
* polish
* make modules with a device kwarg bypass the normal init
* change the unit test to adapt to the updated context
This commit is contained in:
  parent 2c8c05675d
  commit 2053e138a2
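For orientation, a minimal sketch of the workflow this commit enables, mirroring the updated unit test in the diff below (construction happens on PyTorch's meta device, so no real memory is allocated until lazy_init_parameters materializes the weights; the device and call_back arguments are taken from the signature shown in the diff):

    import torch
    from torchvision.models import resnet34
    from colossalai.utils.model.lazy_init_context import LazyInitContext

    ctx = LazyInitContext()
    with ctx:
        # inside the context, __init__ of every nn.Module subclass is patched,
        # so all parameters and buffers are created as meta tensors (no storage)
        model = resnet34(num_classes=10)

    assert all(p.is_meta for p in model.parameters())

    # materialize real tensors (CPU by default) and replay the cached init calls;
    # an optional call_back receives each materialized parameter (e.g. for sharding)
    ctx.lazy_init_parameters(model, device='cpu')
    assert not any(p.is_meta for p in model.parameters())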
colossalai/utils/model/lazy_init_context.py
@@ -7,6 +7,8 @@ import types
 import inspect
 import typing
 from typing import List, Callable
+from colossalai.utils.model.utils import substitute_init_recursively
+
 
 class LazyInitContext():
     """
@@ -36,29 +38,31 @@ class LazyInitContext():
         extra_torch_tensor_func (List[str]): extra torch tensor functions related
             to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
     """
 
     tensor_set_value_func = ['zero_']
 
     def __init__(self, extra_torch_tensor_func: List[str] = None):
         self._intercepted_init_func_cache = []
         self._nn_init_methods = self._get_nn_init_methods()
         self._torch_mod_cls = torch.nn.modules.module.Module
 
         if extra_torch_tensor_func:
             # use tuple to remove duplicates
             self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func)
         else:
             self._torch_tensor_funcs = self.tensor_set_value_func
 
     def _cache_func(self, func):
         """
         This method wraps the ``torch.nn.init`` method so that the function call
         is cached instead of being executed.
         """
 
         def wrapped_init_func(*args, **kwargs):
             self._intercepted_init_func_cache.append(dict(func=func, args=args, kwargs=kwargs))
 
         return wrapped_init_func
 
     def _get_nn_init_methods(self):
         """
         This method looks for all available functions in the ``torch.nn.init``
@@ -66,32 +70,30 @@ class LazyInitContext():
         """
         nn_init_method_names = dir(torch.nn.init)
         nn_init_methods = []
 
         # look for all methods in ``torch.nn.init`` module
         for name in nn_init_method_names:
             nn_init_methods.append((name, getattr(torch.nn.init, name)))
 
         def _has_tensor_in_arg(func):
-            hints = typing.get_type_hints(torch.nn.init.normal_)
+            hints = typing.get_type_hints(func)
             for k, v in hints.items():
                 if v is torch.Tensor:
                     return True
             return False
 
         def _is_init_method(item):
             name, func = item
-            if (not isinstance(func, types.FunctionType) or
-                    name.startswith('_') or
-                    not name.endswith('_') or
-                    not _has_tensor_in_arg(func)):
+            if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')
+                    or not _has_tensor_in_arg(func)):
                 return False
             else:
                 return True
 
         # remove methods which are not init functions
         nn_init_methods = list(filter(_is_init_method, nn_init_methods))
         return nn_init_methods
 
     def _wrap_module_init(self, func):
         """
         This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces
@@ -99,38 +101,47 @@ class LazyInitContext():
         """
         has_device = 'device' in inspect.signature(func).parameters
 
-        def layer_lazy_init(*args, **kwargs):
+        def layer_lazy_init(module, *args, **kwargs):
+            self._intercepted_init_func_cache.append(dict(func=func, module=module, args=args, kwargs=kwargs))
             if has_device:
                 kwargs['device'] = 'meta'
-            func(*args, **kwargs)
+            func(module, *args, **kwargs)
+            if not has_device:
+                module.to('meta')
 
         return layer_lazy_init
 
     def _get_tmp_origin_func_ref(self, name):
         """
         Generate a function name for consistency during caching and retrieving.
         """
         return f'_orig_{name}'
 
     def _patch_nn_init_funcs(self):
         # patch nn.init functions
         for name, func in self._nn_init_methods:
             setattr(torch.nn.init, name, self._cache_func(func))
 
     def _unpatch_nn_init_funcs(self):
         # unpatch nn.init functions
         for name, func in self._nn_init_methods:
             setattr(torch.nn.init, name, func)
 
     def _patch_submodule_init(self):
         # patch classes __init__ methods
-        for sub_cls in self._torch_mod_cls.__subclasses__():
-            sub_cls.__orig_init__ = sub_cls.__init__
-            sub_cls.__init__ = self._wrap_module_init(sub_cls.__init__)
+        def _activate_wrap_init(cls):
+            cls.__orig_init__ = cls.__init__
+            cls.__init__ = self._wrap_module_init(cls.__init__)
+
+        substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init)
 
     def _unpatch_submodule_init(self):
-        for sub_cls in self._torch_mod_cls.__subclasses__():
-            sub_cls.__init__ = sub_cls.__orig_init__
+        def _recover_orig_init(cls):
+            cls.__init__ = cls.__orig_init__
+
+        substitute_init_recursively(self._torch_mod_cls, _recover_orig_init)
 
     def _patch_torch_tensor_funcs(self):
         # patch tensor value-setting functions
         for func_name in self._torch_tensor_funcs:
@@ -138,24 +149,20 @@ class LazyInitContext():
             origin_func = getattr(torch.Tensor, func_name)
             setattr(torch.Tensor, origin_func_name, origin_func)
             setattr(torch.Tensor, func_name, self._cache_func(origin_func))
 
     def _unpatch_torch_tensor_funcs(self):
         for func_name in self._torch_tensor_funcs:
             origin_func_name = self._get_tmp_origin_func_ref(func_name)
             origin_func = getattr(torch.Tensor, origin_func_name)
             setattr(torch.Tensor, func_name, origin_func)
 
     def __enter__(self):
-        self._patch_nn_init_funcs()
-        self._patch_torch_tensor_funcs()
         self._patch_submodule_init()
         return self
 
     def __exit__(self, *args, **kwargs):
         self._unpatch_submodule_init()
-        self._unpatch_torch_tensor_funcs()
-        self._unpatch_nn_init_funcs()
 
     def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
         """
         Initialize the weights of the meta-tensor model.
@@ -169,13 +176,15 @@ class LazyInitContext():
         param_id_to_name = dict()
         for name, param in model.named_parameters():
             param_id_to_name[id(param)] = name
+        for name, buffer in model.named_buffers():
+            param_id_to_name[id(buffer)] = name
 
         def _replace_meta_param_with_real_param(meta_param):
             tensor_id = id(meta_param)
             param_full_name = param_id_to_name[tensor_id]
             real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device)
             real_param = ColoParameter(real_param, requires_grad=meta_param.requires_grad)
 
             if '.' in param_full_name:
                 submodule_name, param_name = param_full_name.rsplit('.', 1)
                 submodule = model.get_submodule(submodule_name)
@@ -183,41 +192,43 @@ class LazyInitContext():
                 submodule = model
                 param_name = param_full_name
             setattr(submodule, param_name, real_param)
 
             # execute call_back function on the materailized tensor
             # this can where sharding comes in
             if call_back:
                 call_back(real_param)
             return real_param
 
 
         # build modules
-        for cache in self._intercepted_init_func_cache:
+        # visit the cache list in reverse order
+        for index in range(len(self._intercepted_init_func_cache)):
+            cache = self._intercepted_init_func_cache[len(self._intercepted_init_func_cache) - index - 1]
             func = cache['func']
+            module = cache['module']
             args = list(cache['args'])
             kwargs = cache['kwargs']
 
             # check args for parameter replacement
             for idx, arg in enumerate(args):
                 if torch.is_tensor(arg):
                     tensor_id = id(arg)
 
                     if tensor_id not in param_id_to_name:
                         continue
                     else:
                         arg = _replace_meta_param_with_real_param(arg)
                         args[idx] = arg
 
             # check kwargs for parameter replacement
             for arg_name, arg in enumerate(kwargs):
                 if torch.is_tensor(arg):
                     tensor_id = id(arg)
 
                     if tensor_id not in param_id_to_name:
                         continue
                     else:
                         arg = _replace_meta_param_with_real_param(arg)
                         kwargs[arg_name] = arg
 
             with torch.no_grad():
-                func(*args, **kwargs)
+                func(module, *args, **kwargs)
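The wrapped layer_lazy_init above leans on two stock PyTorch behaviours: a module whose __init__ accepts a device factory kwarg can build its parameters directly on the meta device, and any other module can be moved there afterwards with .to('meta'). A minimal illustration in plain PyTorch (not part of this commit; the Scale module is made up for the example):

    import torch
    import torch.nn as nn

    # nn.Linear accepts a `device` kwarg (PyTorch >= 1.10), so its parameters
    # are created on the meta device without allocating real storage
    fc = nn.Linear(10, 10, device='meta')
    assert fc.weight.is_meta

    # a module without a `device` kwarg is built normally, then moved to meta
    class Scale(nn.Module):
        def __init__(self):
            super().__init__()
            self.scale = nn.Parameter(torch.ones(4))

    m = Scale().to('meta')
    assert m.scale.is_meta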
colossalai/utils/model/utils.py
@@ -3,9 +3,9 @@ import functools
 from typing import Optional
 
 
-def _substitute_init_recursively(cls, func):
+def substitute_init_recursively(cls, func):
     for subcls in cls.__subclasses__():
-        _substitute_init_recursively(subcls, func)
+        substitute_init_recursively(subcls, func)
         func(subcls)
 
 
@@ -64,7 +64,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
 
         # Replace .__init__() for all existing subclasses of torch.nn.Module
         # Excution self._post_init_method after the default init function.
-        _substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)
+        substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)
 
         # holding on to the current __init__subclass__ for exit
         torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
@@ -87,7 +87,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
             cls.__init__ = cls._old_init
 
         # Replace .__init__() for all existing subclasses of torch.nn.Module
-        _substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)
+        substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)
 
         # Replace .__init__() for future subclasses of torch.nn.Module
         torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)
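Renaming _substitute_init_recursively to a public substitute_init_recursively lets LazyInitContext reuse the same depth-first walk over nn.Module subclasses. The recursion matters because cls.__subclasses__() returns only direct subclasses, so the old flat loop in _patch_submodule_init would miss modules deeper in the inheritance tree. A small self-contained sketch of the pattern (walk_subclasses here is an illustrative stand-in, not part of the codebase):

    import torch.nn as nn

    def walk_subclasses(cls, func):
        # depth-first over the whole inheritance tree, mirroring substitute_init_recursively
        for subcls in cls.__subclasses__():
            walk_subclasses(subcls, func)
            func(subcls)

    direct = set(nn.Module.__subclasses__())
    reached = []
    walk_subclasses(nn.Module, reached.append)
    # the recursive walk also reaches indirect subclasses (e.g. nn.Conv2d derives
    # from _ConvNd), which a single __subclasses__() call would not return
    assert len(reached) > len(direct)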
unit test for LazyInitContext
@@ -1,23 +1,22 @@
 import torch
-import torch.nn as nn
 from colossalai.utils.model.lazy_init_context import LazyInitContext
+from torchvision.models import resnet34
 
 
-def test_lazy_init_ctx():
-    with LazyInitContext() as ctx:
-        model = nn.Linear(10, 10)
-        model.weight.zero_()
-
-    # make sure the weight is a meta tensor
-    assert model.weight.is_meta
-
-    # initialize weights
-    ctx.lazy_init_parameters(model)
-
-    # make sure the weight is not a meta tensor
-    # and initialized correctly
-    assert not model.weight.is_meta and torch.all(model.weight == 0)
+def test_lazy_init():
+    ctx = LazyInitContext()
+    with ctx:
+        model = resnet34(num_classes=10)
+    for param in model.parameters():
+        assert param.is_meta
+    for buffer in model.buffers():
+        assert buffer.is_meta
+    ctx.lazy_init_parameters(model)
+    for param in model.parameters():
+        assert not param.is_meta
+    for buffer in model.buffers():
+        assert not buffer.is_meta
 
 
 if __name__ == '__main__':
-    test_lazy_init_ctx()
+    test_lazy_init()