Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-08 19:38:05 +00:00)
[lazy] refactor lazy init (#3891)
* [lazy] remove old lazy init
* [lazy] refactor lazy init folder structure
* [lazy] fix lazy tensor deepcopy
* [test] update lazy init test
parent 70c8cdecf4
commit dbb32692d2
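The refactor moves the public lazy-initialization API to `colossalai.lazy`. As a quick orientation before the diff, here is a minimal sketch of how the relocated API is expected to be used, based on the new `colossalai/lazy/__init__.py` and the updated test helper further down (the `nn.Linear` model is illustrative only, not part of this commit):

import torch.nn as nn

from colossalai.lazy import LazyInitContext  # new import path introduced by this commit

ctx = LazyInitContext()
with ctx:
    # parameters and buffers created inside the context are deferred (LazyTensor)
    model = nn.Linear(10, 10)

# materialize() turns the deferred parameters into real tensors, as the updated test helper does
model = ctx.materialize(model)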
colossalai/lazy/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
+from .lazy_init import LazyInitContext, LazyTensor
+
+__all__ = [
+    'LazyInitContext',
+    'LazyTensor',
+]
@@ -350,6 +350,13 @@ class LazyTensor(torch.Tensor):
                 copied.requires_grad_()
             return copied
 
-        target = LazyTensor(factory_fn, meta_data=self._meta_data)
+        if self._materialized_data is not None:
+            # self is early materialized
+            copied = self._materialized_data.detach().clone()
+            if self.requires_grad:
+                copied.requires_grad_()
+            target = LazyTensor(lambda: None, concrete_data=copied)
+        else:
+            target = LazyTensor(factory_fn, meta_data=self._meta_data)
 
         memo[id(self)] = target
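The hunk above makes `LazyTensor.__deepcopy__` handle an early-materialized tensor by cloning its concrete data instead of re-running the factory function, and registers the copy in `memo`. A hedged sketch of the behaviour this fix targets, mirroring the updated `check_lazy_init` helper later in this diff (`build_model` is a hypothetical factory returning an `nn.Module`):

from copy import deepcopy

from colossalai.lazy import LazyInitContext

ctx = LazyInitContext()
with ctx:
    model = build_model()      # hypothetical factory; any nn.Module works here
copied = deepcopy(model)       # exercises LazyTensor.__deepcopy__ on the deferred params

# the original and the copy can be materialized independently and should end up identical
model = ctx.materialize(model)
copied = ctx.materialize(copied)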
@@ -1,242 +0,0 @@ (file deleted)
#!/usr/bin/env python
# coding: utf-8

import inspect
import types
from typing import Callable, List

import torch
import torch.nn as nn

from colossalai.tensor import ColoParameter, ColoTensor
from colossalai.utils.model.utils import substitute_init_recursively


class LazyInitContext():
    """
    A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
    initialization functions for lazy initialization

    Note:
        This API is only experimental and subject to future changes.

    Usage:
        with LazyInitContext() as ctx:
            model = nn.Linear(10, 10)
            model.weight.zero_()

        # make sure the weight is a meta tensor
        assert model.weight.is_meta

        # initialize weights
        ctx.lazy_init_parameters(model)

        # make sure the weight is not a meta tensor
        # and initialized correctly
        assert not model.weight.is_meta and torch.all(model.weight == 0)

    Args:
        to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
            argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
        extra_torch_tensor_func (List[str]): extra torch tensor functions related
            to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
    """

    tensor_set_value_func = ['zero_', 'fill_']

    def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
        # TODO: hijack the torch constructor functions as well
        self._to_meta = to_meta
        self._intercepted_nn_init_func_cache = {}
        self._nn_init_methods = self._get_nn_init_methods()
        self._torch_mod_cls = torch.nn.modules.module.Module

        if extra_torch_tensor_func:
            # use tuple to remove duplicates
            self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func)
        else:
            self._torch_tensor_funcs = self.tensor_set_value_func

    @property
    def to_meta(self):
        return self._to_meta

    def _cache_init_func(self, func):
        """
        This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions
        so that the function call is cached instead of being executed.
        """

        def wrapped_init_func(tensor, *args, **kwargs):
            if tensor not in self._intercepted_nn_init_func_cache:
                self._intercepted_nn_init_func_cache[tensor] = []
            self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs))

        return wrapped_init_func

    def _get_nn_init_methods(self):
        """
        This method looks for all available functions in the ``torch.nn.init``
        module.
        """
        nn_init_method_names = dir(torch.nn.init)
        nn_init_methods = []

        # look for all methods in ``torch.nn.init`` module
        for name in nn_init_method_names:
            nn_init_methods.append((name, getattr(torch.nn.init, name)))

        def _is_init_method(item):
            name, func = item

            if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')):
                return False
            else:
                return True

        # remove methods which are not init functions
        nn_init_methods = list(filter(_is_init_method, nn_init_methods))
        return nn_init_methods

    def _wrap_module_init(self, func):
        """
        This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces
        the argument device with value 'meta' so that all modules are created as meta tensors.
        """
        has_device = 'device' in inspect.signature(func).parameters

        def layer_lazy_init(module, *args, **kwargs):
            # if this module contains device argument
            # we set it to meta to initialize as meta backend
            if has_device:
                kwargs['device'] = 'meta'
            func(module, *args, **kwargs)

            # if device is not found, we intialize it and convert to meta
            if not has_device:
                module.to('meta')

        return layer_lazy_init

    def _get_tmp_origin_func_ref(self, name):
        """
        Generate a function name for consistency during caching and retrieving.
        """
        return f'_orig_{name}'

    def _patch_nn_init_funcs(self):
        # patch nn.init functions
        for name, func in self._nn_init_methods:
            setattr(torch.nn.init, name, self._cache_init_func(func))

    def _unpatch_nn_init_funcs(self):
        # unpatch nn.init functions
        for name, func in self._nn_init_methods:
            setattr(torch.nn.init, name, func)

    def _patch_submodule_init(self):
        # patch classes __init__ methods

        def _activate_wrap_init(cls):
            cls.__orig_init__ = cls.__init__
            cls.__init__ = self._wrap_module_init(cls.__init__)

        substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init, set())

    def _unpatch_submodule_init(self):

        def _recover_orig_init(cls):
            cls.__init__ = cls.__orig_init__

        substitute_init_recursively(self._torch_mod_cls, _recover_orig_init, set())

    def _patch_torch_tensor_funcs(self):
        # patch tensor value-setting functions
        for func_name in self._torch_tensor_funcs:
            origin_func_name = self._get_tmp_origin_func_ref(func_name)
            origin_func = getattr(torch.Tensor, func_name)
            setattr(torch.Tensor, origin_func_name, origin_func)
            setattr(torch.Tensor, func_name, self._cache_init_func(origin_func))

    def _unpatch_torch_tensor_funcs(self):
        for func_name in self._torch_tensor_funcs:
            origin_func_name = self._get_tmp_origin_func_ref(func_name)
            origin_func = getattr(torch.Tensor, origin_func_name)
            setattr(torch.Tensor, func_name, origin_func)

    def __enter__(self):
        self._patch_torch_tensor_funcs()
        self._patch_nn_init_funcs()

        if self._to_meta:
            self._patch_submodule_init()
        return self

    def __exit__(self, *args, **kwargs):
        if self._to_meta:
            self._unpatch_submodule_init()
        self._unpatch_nn_init_funcs()
        self._unpatch_torch_tensor_funcs()

    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
        """
        Initialize the weights of the meta-tensor model.

        Args:
            model (`torch.nn.Module`): the model instantiated under the context.
            device (str): the device on which weights are initialized

        """

        def _init_recursively(module: nn.Module):
            # recursively initialize the module
            for mod in module.children():
                _init_recursively(mod)

            # initialize and shard tensors directly attached to the current module
            for name, param in module.named_parameters(recurse=False):
                _init_and_shard(module, name, param)

            for name, buf in module.named_buffers(recurse=False):
                _init_and_shard(module, name, buf)

        @torch.no_grad()
        def _init_and_shard(module, name, tensor):
            # check whether the tensor is a buffer or parameter
            is_param = isinstance(tensor, nn.parameter.Parameter)

            # get sharding spec
            dist_spec = getattr(tensor, 'dist_spec', None)
            pg = getattr(tensor, 'pg', None)
            comp_spec = getattr(tensor, 'comp_spec', None)

            # convert the tensor from meta to materialized one
            if tensor.is_meta:
                materialized_tensor = torch.empty_like(tensor, device=device)
                # if this tensor is a meta tensor, it must have an init function
                assert tensor in self._intercepted_nn_init_func_cache
            else:
                materialized_tensor = tensor

            # apply init function
            if tensor in self._intercepted_nn_init_func_cache:
                init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
                init_func(materialized_tensor, *args, **kwargs)

            # convert it to ColoTensor or ColoParameter
            if is_param:
                tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
            else:
                tensor = ColoTensor.from_torch_tensor(materialized_tensor)

            # override the original tensor
            with torch.no_grad():
                setattr(module, name, tensor)

            # apply sharding
            if dist_spec:
                tensor.process_group = pg
                tensor.set_tensor_spec(dist_spec, comp_spec)

        _init_recursively(model)

        return model
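The removed context above worked by monkey-patching `torch.nn.init` functions and tensor value-setting methods so that initialization calls were recorded rather than executed, then replayed once real storage existed. A minimal standalone sketch of that recording idea (not part of the codebase; names are illustrative):

import torch

_recorded = {}    # tensor -> list of (init_fn, args, kwargs), like _intercepted_nn_init_func_cache

def record(init_fn):
    # analogue of _cache_init_func: cache the call instead of executing it
    def wrapper(tensor, *args, **kwargs):
        _recorded.setdefault(tensor, []).append((init_fn, args, kwargs))
    return wrapper

zeros_ = record(torch.nn.init.zeros_)
w = torch.empty(2, 2, device='meta')   # meta tensor: shape only, no storage
zeros_(w)                              # call is recorded, not executed

# later, replay the last recorded call against a freshly allocated tensor
materialized = torch.empty(2, 2)
fn, args, kwargs = _recorded[w][-1]
fn(materialized, *args, **kwargs)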
@@ -2,13 +2,14 @@ import itertools
 from collections import OrderedDict
 from contextlib import nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 
 from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -16,7 +17,6 @@ from colossalai.tensor import ReplicaSpec
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored
-from colossalai.utils.model.experimental import LazyTensor
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -96,11 +96,14 @@ class ZeroDDP(ColoDDP):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
         super().__init__(module, process_group=ColoProcessGroup())
-        self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
-    def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
+    def _get_non_persistent_buffers_set(self,
+                                        module,
+                                        memo: Optional[Set[nn.Module]] = None,
+                                        prefix: str = '',
+                                        remove_duplicate: bool = True):
         r"""
         Args:
             memo: a memo to store the set of modules already added to the result
@@ -115,16 +118,17 @@ class ZeroDDP(ColoDDP):
         if module not in memo:
             if remove_duplicate:
                 memo.add(module)
-            self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
+            self_non_persistent_set = set(
+                map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
             for name, sub_module in module._modules.items():
                 if sub_module is None:
                     continue
                 submodule_prefix = prefix + ('.' if prefix else '') + name
-                child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
+                child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
+                                                                                remove_duplicate)
                 self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
         return self_non_persistent_set
 
     def _post_forward(self):
         """This function is only triggered for inference.
         """
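For context on the hunks above: `_get_non_persistent_buffers_set` walks the module tree and collects the dotted names of buffers registered with `persistent=False`, which PyTorch keeps out of the `state_dict`. A small sketch (an assumption for illustration, not taken from this diff) of how such a buffer is registered and therefore picked up by the helper:

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False adds the name to module._non_persistent_buffers_set,
        # which the helper above prefixes (e.g. 'block.running_scale') and unions
        # across submodules.
        self.register_buffer('running_scale', torch.ones(1), persistent=False)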
@@ -8,10 +8,10 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.fx import is_compatible_with_meta
+from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.model.experimental import LazyInitContext
 from colossalai.zero import ColoInitContext
 from tests.kit.model_zoo import model_zoo
 
@@ -1,12 +1,13 @@
 import random
+from copy import deepcopy
 from typing import Any, Callable, Optional, Tuple
 
 import numpy as np
 import torch
 from packaging import version
 
+from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
 from colossalai.tensor.d_tensor.layout_converter import to_global
-from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
 from tests.kit.model_zoo.registry import ModelAttribute
 
 SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')
@@ -31,6 +32,9 @@ def assert_model_equal(m1: torch.nn.Module, m2: torch.nn.Module) -> None:
         assert n1 == n2
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
 
+    for p1, p2 in zip(m1.parameters(), m2.parameters()):
+        assert p1.requires_grad == p2.requires_grad
+
 
 def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict],
                          output_transform_fn: Callable[[Any], dict]) -> None:
@@ -65,10 +69,14 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
     ctx = LazyInitContext()
     with ctx:
         deferred_model = model_fn()
+    copied_deferred_model = deepcopy(deferred_model)
     deferred_model = ctx.materialize(deferred_model, verbose=verbose)
+    copied_deferred_model = ctx.materialize(copied_deferred_model, verbose=verbose)
     assert_model_equal(model, deferred_model)
+    assert_model_equal(deferred_model, copied_deferred_model)
     if check_forward:
         assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
+        assert_forward_equal(deferred_model, copied_deferred_model, data_gen_fn, output_transform_fn)
     if verbose:
         print(f'{model.__class__.__name__} pass')
 
@@ -12,7 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.common import print_rank_0
 
 try:
-    from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
+    from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
 except:
     pass
 from lazy_init_utils import SUPPORT_LAZY, assert_dist_model_equal, set_seed
@@ -1,51 +0,0 @@ (file deleted)
import torch
from colossalai.utils.model.lazy_init_context import LazyInitContext
from torchvision.models import resnet34
import random
import numpy as np

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)


def test_lazy_init_with_meta():
    ctx = LazyInitContext(to_meta=True)
    with ctx:
        model = resnet34(num_classes=10)

    for param in model.parameters():
        assert param.is_meta
    for buffer in model.buffers():
        assert buffer.is_meta

    ctx.lazy_init_parameters(model)

    for name, param in model.named_parameters():
        assert not param.is_meta, name

    for buffer in model.buffers():
        assert not buffer.is_meta


def test_lazy_init_without_meta():
    ctx = LazyInitContext(to_meta=False)
    with ctx:
        model = resnet34(num_classes=10)

    for param in model.parameters():
        assert not param.is_meta
    for buffer in model.buffers():
        assert not buffer.is_meta

    conv1_weight_before_init = model.conv1.weight.clone()
    ctx.lazy_init_parameters(model)
    conv1_weight_after_init = model.conv1.weight.clone()

    assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init)


if __name__ == '__main__':
    test_lazy_init_with_meta()
    test_lazy_init_without_meta()