[lazyinit] add correctness verification (#3147)
* [lazyinit] fix shared module
* [tests] add lazy init test utils
* [tests] add torchvision for lazy init
* [lazyinit] fix pre op fn
* [lazyinit] handle legacy constructor
* [tests] refactor lazy init test models
* [tests] refactor lazy init test utils
* [lazyinit] fix ops don't support meta
* [tests] lazy init test timm models
* [lazyinit] fix set data
* [lazyinit] handle apex layers
* [tests] lazy init test transformers models
* [tests] lazy init test torchaudio models
* [lazyinit] fix import path
* [tests] lazy init test torchrec models
* [tests] update torch version in CI
* [tests] revert torch version in CI
* [tests] skip lazy init test
@@ -1,17 +1,16 @@
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union

 import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.utils._pytree import tree_map

-from colossalai.fx.profiler import MetaTensor
+from colossalai.fx.profiler.tensor import MetaTensor

 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
-_TorchFactoryMethod = [
+_NORMAL_FACTORY = [
     "arange",
     "empty",
     "eye",
     "full",
     "linspace",
     "logspace",
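The renamed _NORMAL_FACTORY list enumerates the torch creation ops that LazyInitContext patches further down in this file. The whole approach rests on the fact that a tensor on the meta device carries only shape and dtype metadata and allocates no storage. A short standalone check of that fact (not part of the commit):

import torch

t = torch.empty(1024, 1024, device='meta')   # no real memory is allocated
print(t.shape, t.dtype, t.device)            # torch.Size([1024, 1024]) torch.float32 meta
print(t.is_meta)                             # True
# shape-level ops such as t + 1 still work symbolically; t.tolist() would fail,
# because there is no underlying data to read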
@@ -24,17 +23,39 @@ _TorchFactoryMethod = [
     "tensor",
 ]

+# factory function that does not support meta tensor backend
+_NO_META_FACTORY = [
+    "eye",
+]
+
 _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']

+_LEGACY_TENSOR_CONSTRUCTOR = {
+    'FloatTensor': torch.float,
+    'DoubleTensor': torch.double,
+    'HalfTensor': torch.half,
+    'BFloat16Tensor': torch.bfloat16,
+    'ByteTensor': torch.uint8,
+    'CharTensor': torch.int8,
+    'ShortTensor': torch.short,
+    'IntTensor': torch.int,
+    'LongTensor': torch.long,
+    'BoolTensor': torch.bool,
+}
+

 class _MyTensor(Tensor):
     """This class is only for correctness verification.
     """
     _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None

-    def __new__(cls, func, *args, dtype=None, device=None, **kwargs) -> '_MyTensor':
+    def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor':
         cls._pre_op_fn()
-        data = func(*args, dtype=dtype, device=device, **kwargs)
+        if concrete_data is not None:
+            # uniform api as LazyTensor
+            data = concrete_data
+        else:
+            data = func(*args, **kwargs)
         return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)

     @classmethod
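_LEGACY_TENSOR_CONSTRUCTOR maps each legacy constructor name to the dtype it implies, so the wrap_legacy_constructor added further down can reroute such calls through the patched tensor/empty factories. A minimal standalone check of that equivalence, not taken from the commit:

import torch

# The legacy constructors imply a dtype; these two calls allocate equivalent
# uninitialized int64 tensors of shape (2, 3).
a = torch.LongTensor(2, 3)
b = torch.empty(2, 3, dtype=torch.long)
print(a.dtype == b.dtype == torch.int64, a.shape == b.shape)

# Passing data instead of sizes goes through torch.tensor semantics instead.
c = torch.LongTensor([1, 2, 3])
d = torch.tensor([1, 2, 3], dtype=torch.long)
print(torch.equal(c, d))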
@@ -66,11 +87,13 @@ class LazyTensor(torch.Tensor):
        >>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization
        >>> z = x.tolist()
        >>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed
-       >>> x.data = torch.rand(2, 3) # directly set data of a lazy tensor is not allowed
+       >>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed

    2. Cases that ``LazyTensor`` becomes eager (early materialization).
        >>> b = a[:, 2:] # get a slice of a lazy tensor triggers early materialization
        >>> chunks = a.split(3) # this also triggers early materialization
+       >>> x.data = torch.rand(2, 3) # directly setting data of a lazy tensor triggers early materialization

    """
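A hypothetical usage sketch mirroring the docstring above; the import path below is an assumption, not something this diff states:

# Hypothetical usage; the module path is assumed for illustration only.
import torch
from colossalai.utils.model.experimental import LazyTensor  # assumed path

x = LazyTensor(torch.empty, 2, 3)   # records the factory call, allocates nothing yet
b = x[:, 1:]                        # __getitem__ is in _EARLY_MATERIALIZED_OPS, so x materializes early
y = x.materialize()                 # replays the deferred factory call, returns a real torch.Tensor
print(type(y), y.shape)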
@@ -79,12 +102,16 @@ class LazyTensor(torch.Tensor):
     _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None

     @staticmethod
-    def __new__(cls, func, *args, meta_data=None, **kwargs):
-        if meta_data is None:
-            device = kwargs.get('device', 'cpu')
-            elem = func(*args, **{**kwargs, 'device': 'meta'})
-            meta_data = MetaTensor(elem, fake_device=device)
-        elem = meta_data._tensor
+    def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
+        if concrete_data is not None:
+            # some ops don't support meta backend and should have concrete data
+            elem = concrete_data
+        else:
+            if meta_data is None:
+                device = kwargs.get('device', 'cpu')
+                elem = func(*args, **{**kwargs, 'device': 'meta'})
+                meta_data = MetaTensor(elem, fake_device=device)
+            elem = meta_data._tensor
         r = torch.Tensor._make_wrapper_subclass(cls,
                                                 elem.size(),
                                                 strides=elem.stride(),
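The constructor builds the subclass around a meta-device element via torch.Tensor._make_wrapper_subclass, so shape, strides and dtype are known without allocating storage. A minimal standalone sketch of that pattern (class and names are illustrative, not the library's):

import torch

class MetaBacked(torch.Tensor):
    """Illustrative wrapper subclass whose element lives on the meta device."""

    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        # metadata (shape, strides, dtype) is taken from `elem`; no storage is allocated
        return torch.Tensor._make_wrapper_subclass(cls,
                                                   elem.size(),
                                                   strides=elem.stride(),
                                                   dtype=elem.dtype,
                                                   device=elem.device,
                                                   requires_grad=elem.requires_grad)

elem = torch.empty(2, 3, device='meta')
t = MetaBacked(elem)
print(t.shape, t.dtype, t.device)   # shape/dtype/device are known without real memory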
@@ -96,10 +123,10 @@ class LazyTensor(torch.Tensor):
         r._meta_data = meta_data
         return r

-    def __init__(self, func, *args, meta_data=None, **kwargs):
+    def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
         self._factory_method = (func, args, kwargs)    # (func, args, kwargs)
         self._op_buffer = []    # (func, args, kwargs, replace)
-        self._materialized_data: Optional[torch.Tensor] = None    # materialized data
+        self._materialized_data: Optional[torch.Tensor] = concrete_data    # materialized data

     def materialize(self) -> torch.Tensor:
         """Materialize the ``LazyTensor`` to ``torch.Tensor``.
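__init__ keeps three pieces of state: the deferred factory call, a buffer of ops applied before materialization, and the materialized data (now pre-filled when concrete_data is given). A toy record-and-replay sketch of the same idea, not the library class:

import torch

class _ReplaySketch:
    """Toy record-and-replay holder mirroring the three fields above."""

    def __init__(self, func, *args, **kwargs):
        self._factory_method = (func, args, kwargs)   # how to create the real tensor
        self._op_buffer = []                          # ops applied before materialization
        self._materialized_data = None                # filled once, then reused

    def record(self, func, *args, **kwargs):
        self._op_buffer.append((func, args, kwargs))

    def materialize(self) -> torch.Tensor:
        if self._materialized_data is None:
            func, args, kwargs = self._factory_method
            t = func(*args, **kwargs)
            for op, op_args, op_kwargs in self._op_buffer:
                t = op(t, *op_args, **op_kwargs)
            self._materialized_data = t
        return self._materialized_data

lazy = _ReplaySketch(torch.empty, 2, 3)
lazy.record(torch.Tensor.fill_, 1.0)    # defer fill_ until materialization
print(lazy.materialize())               # a real 2x3 tensor of ones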
@@ -212,7 +239,7 @@ class LazyTensor(torch.Tensor):
            if isinstance(x, LazyTensor):
                if x._materialized_data is not None:
                    # for early materialized tensor, use its materialized data directly
-                    return x._materialized_data
+                    return x._materialized_data.data
                t = x if is_inplace else x.clone()
                t._op_buffer.append((func, args, kwargs))
                meta = x._meta_data.data
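The unwrap path now returns x._materialized_data.data instead of the tensor itself. Presumably this hands back a view that shares storage but is detached from autograd; that reading is an inference, not stated in the commit. The behavior of .data itself is easy to check:

import torch

a = torch.ones(2, 3, requires_grad=True)
b = a.data                            # same storage, but detached from the autograd graph
print(b.requires_grad)                # False
print(b.data_ptr() == a.data_ptr())   # True: no copy is made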
@@ -232,13 +259,10 @@ class LazyTensor(torch.Tensor):
                return lazy_y
            elif type(y) is Tensor:
                # for early materialized tensor
-                with torch._C.DisableTorchFunction():
-                    meta = MetaTensor(y.new_empty(y.shape, dtype=y.dtype, device='meta'), fake_device=y.device)
-                lazy_y = LazyTensor(lambda: None, meta_data=meta)
-                lazy_y._materialized_data = y
-                return lazy_y
+                return LazyTensor(lambda: None, concrete_data=y)
            return y

+        cls._pre_op_fn()
         o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
         if isinstance(o, (tuple, list)):
             return type(o)(wrap(y, i=i) for i, y in enumerate(o))
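This is the usual __torch_function__ unwrap/execute/re-wrap pattern: arguments are unwrapped via tree_map, the real op runs (with _pre_op_fn now called first), and plain-Tensor outputs are re-wrapped, here simply as LazyTensor(..., concrete_data=y). A minimal self-contained subclass showing the same pattern (illustrative, not the library code):

import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    """Illustrative subclass using the same intercept/execute/re-wrap pattern."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f'intercepted: {getattr(func, "__name__", func)}')
        # run the real op with subclass dispatch disabled so we don't recurse
        with torch._C.DisableTorchFunction():
            out = func(*args, **kwargs)

            def wrap(y):
                # re-wrap plain tensor outputs so the subclass keeps propagating
                return y.as_subclass(cls) if isinstance(y, torch.Tensor) else y

            return tree_map(wrap, out)

x = torch.ones(2, 3).as_subclass(LoggingTensor)
y = x + 1                    # intercepted, executed, re-wrapped
print(type(y).__name__)      # LoggingTensor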
@@ -266,7 +290,10 @@ class LazyTensor(torch.Tensor):

     @data.setter
     def data(self, other: 'LazyTensor'):
-        raise NotImplementedError
+        if other is self:
+            return
+        # TODO(ver217): to avoid infinity recursion, do early materialization
+        self._materialized_data = other._materialize_data()

     def tolist(self) -> list:
         t = self.materialize()
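Setting .data used to raise; it now early-materializes the assigned value, with an identity guard so that assigning a tensor to itself cannot recurse. A toy property sketch of that guard (not the library class):

class ToyLazy:
    """Toy stand-in showing the guarded data setter."""

    def __init__(self, value=None):
        self._materialized = value

    @property
    def data(self):
        return self._materialized

    @data.setter
    def data(self, other):
        if other is self:
            # e.g. `t.data = t` must be a no-op rather than an infinite recursion
            return
        self._materialized = other._materialized if isinstance(other, ToyLazy) else other

t = ToyLazy()
t.data = t        # ignored by the identity guard
t.data = 42       # "materializes" the toy value
print(t.data)     # 42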
@@ -330,18 +357,61 @@ class LazyInitContext:

            return wrapper, target

+        def wrap_legacy_constructor(target, dtype):
+            # legacy constructor (e.g. torch.LongTensor())
+            def wrapper(*args, **kwargs):
+                if len(args) == 1 and isinstance(args[0], torch.Tensor):
+                    # (Tensor other)
+                    return args[0]
+                elif len(args) == 1:
+                    # (object data, *, torch.device device)
+                    kwargs = {**kwargs, 'dtype': dtype}
+                    replaced, orig = self.overrides['tensor']
+                    return replaced(*args, **kwargs)
+                elif _is_int_tuple(args):
+                    # (tuple of ints size, *, torch.device device)
+                    kwargs = {**kwargs, 'dtype': dtype}
+                    replaced, orig = self.overrides['empty']
+                    return replaced(*args, **kwargs)
+                else:
+                    raise TypeError(
+                        f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)'
+                    )
+
+            return wrapper, target
+
+        def wrap_no_meta_factory(target):
+            # factory functions which don't support meta tensor backend
+            def wrapper(*args, **kwargs):
+                tensor = target(*args, **kwargs)
+                return self.tensor_cls(lambda: None, concrete_data=tensor)
+
+            return wrapper, target
+
         self.overrides = {
             target: wrap_factory_method(getattr(torch, target))
-            for target in _TorchFactoryMethod
+            for target in _NORMAL_FACTORY
             if callable(getattr(torch, target, None))
         }

         self.overrides.update({
             target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like'))
-            for target in _TorchFactoryMethod
+            for target in _NORMAL_FACTORY
             if callable(getattr(torch, target + '_like', None))
         })

+        self.overrides.update({
+            target: wrap_legacy_constructor(getattr(torch, target), dtype)
+            for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
+            if callable(getattr(torch, target, None))
+        })
+
+        self.overrides.update({
+            target: wrap_no_meta_factory(getattr(torch, target))
+            for target in _NO_META_FACTORY
+            if callable(getattr(torch, target, None))
+        })
+
         for name, (wrapper, orig) in self.overrides.items():
             setattr(torch, name, wrapper)
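Each override entry stores a (wrapper, original) pair so the original factories can be put back when the context exits. A minimal standalone sketch of this patch-and-restore pattern (class and names are illustrative, not the library's):

import torch

class PatchFactories:
    """Illustrative patch-and-restore context manager for torch factory functions."""

    def __init__(self, targets=('empty', 'zeros', 'ones')):
        self.targets = targets
        self.overrides = {}

    def __enter__(self):
        for name in self.targets:
            orig = getattr(torch, name)

            def wrapper(*args, _orig=orig, _name=name, **kwargs):
                print(f'intercepted torch.{_name}')
                return _orig(*args, **kwargs)

            self.overrides[name] = (wrapper, orig)
            setattr(torch, name, wrapper)
        return self

    def __exit__(self, *exc):
        # put the originals back; the stored `orig` presumably exists for this purpose
        for name, (wrapper, orig) in self.overrides.items():
            setattr(torch, name, orig)

with PatchFactories():
    x = torch.empty(2, 3)   # goes through the wrapper
y = torch.empty(2, 3)       # the original torch.empty is restored here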
@@ -363,34 +433,65 @@ class LazyInitContext:
         param_lazy_cnt = 0
         buf_cnt = 0
         buf_lazy_cnt = 0
+        non_lazy_numel = 0

+        # do post cleaning to handle shared parameter
+        visited_lazy_tensors: List[LazyTensor] = []
+        # handle shared module
+        visited_modules = set()
+
         @torch.no_grad()
         def init_recursively(module: nn.Module):
-            nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt
+            nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel
             # recursively initialize the module
             for mod in module.children():
-                init_recursively(mod)
+                if id(mod) not in visited_modules:
+                    visited_modules.add(id(mod))
+                    init_recursively(mod)

             # initialize tensors directly attached to the current module
             for name, param in module.named_parameters(recurse=False):
                 if verbose:
                     param_cnt += 1
-                    if param._materialized_data is None:
+                    if getattr(param, '_materialized_data', False) is None:
+                        # if no _materialized_data attr, the tensor is not lazy
                         param_lazy_cnt += 1
-                setattr(module, name, param.materialize())
-                param.clean()
+                    else:
+                        non_lazy_numel += param.numel()
+                if hasattr(param, 'materialize'):
+                    # TODO(ver217): apex layers cannot be captured
+                    visited_lazy_tensors.append(param)
+                    setattr(module, name, param.materialize())

             for name, buf in module.named_buffers(recurse=False):
                 if verbose:
                     buf_cnt += 1
-                    if buf._materialized_data is None:
+                    if getattr(buf, "_materialized_data", False) is None:
+                        # if no _materialized_data attr, the tensor is not lazy
                         buf_lazy_cnt += 1
-                setattr(module, name, buf.materialize())
-                buf.clean()
+                    else:
+                        non_lazy_numel += buf.numel()
+                if hasattr(buf, 'materialize'):
+                    # TODO(ver217): apex layers cannot be captured
+                    visited_lazy_tensors.append(buf)
+                    setattr(module, name, buf.materialize())

         init_recursively(module)

+        for t in visited_lazy_tensors:
+            t.clean()
+
         if verbose:
             print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
             print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
+            print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)')
         return module


+def _is_int_tuple(args) -> bool:
+    if not isinstance(args, tuple):
+        return False
+    for x in args:
+        if not isinstance(x, int):
+            return False
+    return True
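The new visited_modules set is what fixes shared (tied) submodules: a module reachable from two parents is now recursed into only once, and lazy tensors are cleaned in a single pass afterwards. A small standalone sketch of the visited-id traversal (illustrative, not the library code):

import torch.nn as nn

def init_recursively_sketch(module: nn.Module, visited: set, counts: dict):
    # recurse into each child once, skipping modules that were already visited
    for child in module.children():
        if id(child) not in visited:
            visited.add(id(child))
            init_recursively_sketch(child, visited, counts)
    counts[id(module)] = counts.get(id(module), 0) + 1   # "initialize" this module

class Block(nn.Module):
    def __init__(self, proj: nn.Linear):
        super().__init__()
        self.proj = proj                    # both blocks reference the same Linear

shared = nn.Linear(4, 4)
model = nn.Sequential(Block(shared), Block(shared))

counts = {}
init_recursively_sketch(model, visited=set(), counts=counts)
print(counts[id(shared)])                   # 1: the shared Linear is handled only once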