From 3e05c07bb8921f2a8f9736b6f6673d4e9f1697d0 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Thu, 21 Sep 2023 16:30:23 +0800
Subject: [PATCH] [lazy] support torch 2.0 (#4763)

* [lazy] support _like methods and clamp

* [lazy] pass transformers models

* [lazy] fix device move and requires grad

* [lazy] fix requires grad and refactor api

* [lazy] fix requires grad
---
 .isort.cfg                      |   1 +
 colossalai/lazy/construction.py |  87 ++++++++++++++
 colossalai/lazy/lazy_init.py    | 207 ++++++++++++++++++--------------
 tests/test_lazy/test_models.py  |   8 +-
 tests/test_lazy/test_ops.py     |  64 ++++++++++
 5 files changed, 273 insertions(+), 94 deletions(-)
 create mode 100644 colossalai/lazy/construction.py
 create mode 100644 tests/test_lazy/test_ops.py

diff --git a/.isort.cfg b/.isort.cfg
index 4f881c8b3..ccbf575fd 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -4,3 +4,4 @@ multi_line_output=3
 include_trailing_comma = true
 ignore_comments = true
 profile = black
+honor_noqa = true
diff --git a/colossalai/lazy/construction.py b/colossalai/lazy/construction.py
new file mode 100644
index 000000000..6764eaf77
--- /dev/null
+++ b/colossalai/lazy/construction.py
@@ -0,0 +1,87 @@
+from contextlib import contextmanager
+from typing import Callable, Dict, Tuple
+
+import torch
+
+__all__ = [
+    "_LEGACY_TENSOR_CONSTRUCTOR",
+    "_NO_META_FACTORY",
+    "_NORMAL_FACTORY",
+    "ConstructorManager",
+]
+
+# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
+_NORMAL_FACTORY = [
+    "arange",
+    "full",
+    "empty",
+    "linspace",
+    "logspace",
+    "ones",
+    "rand",
+    "randn",
+    "randint",
+    "randperm",
+    "zeros",
+    "tensor",
+]
+
+# factory functions that do not support the meta tensor backend
+_NO_META_FACTORY = [
+    "eye",
+]
+
+_LEGACY_TENSOR_CONSTRUCTOR = {
+    "FloatTensor": torch.float,
+    "DoubleTensor": torch.double,
+    "HalfTensor": torch.half,
+    "BFloat16Tensor": torch.bfloat16,
+    "ByteTensor": torch.uint8,
+    "CharTensor": torch.int8,
+    "ShortTensor": torch.short,
+    "IntTensor": torch.int,
+    "LongTensor": torch.long,
+    "BoolTensor": torch.bool,
+}
+
+
+class ConstructorManager:
+    # function name: (new, old)
+    overwrites: Dict[str, Tuple[Callable, Callable]] = {}
+    changed: bool = False
+
+    @staticmethod
+    def apply(overwrites: Dict[str, Tuple[Callable, Callable]]):
+        ConstructorManager.overwrites.clear()
+        ConstructorManager.overwrites.update(overwrites)
+        ConstructorManager.redo()
+
+    @staticmethod
+    def undo():
+        assert ConstructorManager.changed, "No constructor change to undo"
+        for name, (new, old) in ConstructorManager.overwrites.items():
+            setattr(torch, name, old)
+        ConstructorManager.changed = False
+
+    @staticmethod
+    def redo():
+        assert not ConstructorManager.changed, "Constructor already changed"
+        for name, (new, old) in ConstructorManager.overwrites.items():
+            setattr(torch, name, new)
+        ConstructorManager.changed = True
+
+    @staticmethod
+    @contextmanager
+    def disable():
+        enabled = ConstructorManager.changed
+        if enabled:
+            ConstructorManager.undo()
+        yield
+        if enabled:
+            ConstructorManager.redo()
+
+    @staticmethod
+    def clear():
+        if ConstructorManager.changed:
+            ConstructorManager.undo()
+        ConstructorManager.overwrites.clear()
diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index ebaf2e160..f29e997da 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -1,17 +1,18 @@
 from types import MethodType
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
+from packaging import version
 from torch import Tensor
 from torch.nn import Parameter
 from torch.utils._pytree import tree_map
 
-from colossalai._analyzer._subclasses import MetaTensor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.d_tensor import distribute_tensor
-from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.logging import get_dist_logger
+
+from .construction import ConstructorManager
+
+import colossalai._analyzer._subclasses._meta_registration  # noqa
 
 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
 _NORMAL_FACTORY = [
@@ -41,6 +42,9 @@ _EARLY_MATERIALIZED_OPS = ["__getitem__", "split"]
 # These ops cannot be unwrapped using .data
 _CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"]
 
+# These ops are not related to tensor values and should not be rerun
+_NO_RERUN_OPS = ["__get__", "numel", "size", "dim"]
+
 _LEGACY_TENSOR_CONSTRUCTOR = {
     "FloatTensor": torch.float,
     "DoubleTensor": torch.double,
@@ -54,6 +58,20 @@ _LEGACY_TENSOR_CONSTRUCTOR = {
     "BoolTensor": torch.bool,
 }
 
+# These ops have at least one lazy tensor argument and may also have scalar arguments;
+# scalar values should be converted to meta tensors
+# this is a hack for torch 2.0
+_EXPAND_SCALAR_OPS = [
+    "where",
+    "clamp",
+    "clamp_min",
+    "clamp_max",
+    "clamp_",
+    "clamp_min_",
+    "clamp_max_",
+]
+_old_tensor_factory = torch.tensor
+
 _EMPTY_DATA = torch.empty(0)
 
 
@@ -145,34 +163,48 @@ class LazyTensor(torch.Tensor):
     """
 
     _repr = True
-    _meta_data: Optional[MetaTensor] = None  # shape, dtype, device
+    _meta_data: Optional[torch.Tensor] = None  # shape, dtype, device
     _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
 
     default_device: Optional[torch.device] = None
+    _device: torch.device  # fake device of the meta tensor
 
     @staticmethod
    def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
+        # Note for torch 2.0:
+        # torch 2.0 disables torch dispatch for subclasses of Tensor,
+        # so MetaTensor cannot be used here.
+        # Now the lazy tensor handles device injection and holds the meta tensor itself.
         if concrete_data is not None:
             # some ops don't support meta backend and should have concrete data
             elem = concrete_data
         else:
             if meta_data is None:
-                device = kwargs.get("device", "cpu")
-                elem = func(*args, **{**kwargs, "device": "meta"})
-                meta_data = MetaTensor(elem, device=device)
-            elem = meta_data._tensor
+                with ConstructorManager.disable():
+                    # disable creating lazy tensors in inner ops; this is a hack for torch 2.0
+                    meta_data = func(*args, **{**kwargs, "device": "meta"})
+            elem = meta_data
         # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
         r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
         r._meta_data = meta_data
+
         return r
 
     def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
+        self._device = torch.device(kwargs.get("device", None) or "cpu")
         if func.__name__ in _NORMAL_FACTORY:
             kwargs = {**kwargs, "device": LazyTensor.default_device}
         self._factory_method = (func, args, kwargs)  # (func, args, kwargs)
         self._op_buffer = []  # (func, args, kwargs, replace)
         self._materialized_data: Optional[torch.Tensor] = concrete_data  # materialized data
 
+    @property
+    def device(self) -> torch.device:
+        return self._materialized_data.device if self._materialized_data is not None else self._device
+
+    def __repr__(self):
+        return f"LazyTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
+
     def materialize(self) -> torch.Tensor:
         """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
 
@@ -183,20 +215,6 @@ class LazyTensor(torch.Tensor):
         self.clean()
         return _convert_cls(self, target)
 
-    def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
-        """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
-
-        Args:
-            layout (Layout): Distribution layout.
-
-        Returns:
-            torch.Tensor: The distributed tensor (self).
-        """
-        target = self._materialize_data()
-        self.clean()
-        local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
-        return _convert_cls(self, local_tensor)
-
     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
         delattr(self, "_factory_method")
@@ -299,45 +317,80 @@ class LazyTensor(torch.Tensor):
                     # for early materialized tensor, use its materialized data directly
                     return x._materialized_data if is_change_meta_op else x._materialized_data.data
                 t = x if is_inplace else x.clone()
-                t._op_buffer.append((func, args, kwargs))
+                if func.__name__ not in _NO_RERUN_OPS:
+                    t._op_buffer.append((func, args, kwargs))
                 meta = x._meta_data if is_change_meta_op else x._meta_data.data
                 meta_to_lazy[meta] = t
                 return meta
+            elif (
+                version.parse(torch.__version__) >= version.parse("2.0.0")
+                and func.__name__ in _EXPAND_SCALAR_OPS
+                and not isinstance(x, torch.Tensor)
+            ):
+                return _old_tensor_factory(x, device="meta")
             return x
 
         def wrap(y, i=None):
-            if isinstance(y, MetaTensor):
-                if y in meta_to_lazy:
-                    # inplace op, just return origin lazy tensor
-                    return meta_to_lazy[y]
+            if isinstance(y, torch.Tensor):
+                if y.is_meta:
+                    if y in meta_to_lazy:
+                        # inplace op, just return origin lazy tensor
+                        return meta_to_lazy[y]
+                    else:
+                        # out of place op, create new lazy tensor
+                        fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
+                        fn.__name__ = func.__name__
+                        lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
+                        return lazy_y
                 else:
-                    # out of place op, create new lazy tensor
-                    fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
-                    fn.__name__ = func.__name__
-                    lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
-                    return lazy_y
-            elif type(y) is Tensor:
-                # for early materialized tensor
-                return LazyTensor(lambda: None, concrete_data=y)
+                    # for early materialized tensor
+                    return LazyTensor(lambda: None, concrete_data=y)
             return y
 
         cls._pre_op_fn()
-        o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
+        with ConstructorManager.disable():
+            # disable creating lazy tensors in inner ops; this is a hack for torch 2.0
+            o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
         if isinstance(o, (tuple, list)):
             return type(o)(wrap(y, i=i) for i, y in enumerate(o))
         return wrap(o)
 
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        pass  # skip
+    def to(self, *args, **kwargs) -> torch.Tensor:
+        if self._materialized_data is not None:
+            return LazyTensor(lambda: None, concrete_data=self._materialized_data.to(*args, **kwargs))
+
+        device = None
+
+        def replace(x):
+            nonlocal device
+            if isinstance(x, (str, int, torch.device)) and not isinstance(x, bool):
+                device = x
+                return torch.device("meta")
+            return x
+
+        meta_data = self._meta_data.to(*tree_map(replace, args), **tree_map(replace, kwargs))
+
+        if meta_data is self._meta_data and device == self.device:
+            return self
+
+        def factory_fn(t: torch.Tensor, **kw):
+            return t.to(*args, **kwargs)
+
+        return LazyTensor(factory_fn, self, meta_data=meta_data, device=device)
+
+    def cpu(self, memory_format: torch.memory_format = torch.preserve_format):
+        return self.to(device=torch.device("cpu"), memory_format=memory_format)
+
+    def cuda(self, device=None, non_blocking=False, memory_format: torch.memory_format = torch.preserve_format):
+        device = torch.device(device or "cuda")
+        return self.to(device=device, non_blocking=non_blocking, memory_format=memory_format)
 
     def clone(self) -> "LazyTensor":
-        def factory_fn():
+        def factory_fn(t: torch.Tensor, **kw):
             # if self is materialized, return self
-            new_tensor = self.materialize() if type(self) is LazyTensor else self
-            return new_tensor.clone()
+            return t.clone()
 
-        target = LazyTensor(factory_fn, meta_data=self._meta_data)
+        target = LazyTensor(factory_fn, self, meta_data=self._meta_data)
 
         return target
 
@@ -353,17 +406,16 @@ class LazyTensor(torch.Tensor):
         if id(self) in memo:
             return memo[id(self)]
 
-        def factory_fn():
+        def factory_fn(t: torch.Tensor, **kw):
             # if self is materialized, return self
-            new_tensor = self.materialize() if type(self) is LazyTensor else self
-            return _copy_tensor(new_tensor, new_tensor.requires_grad)
+            return _copy_tensor(t, t.requires_grad)
 
         if self._materialized_data is not None:
             # self is early materialized
             copied = _copy_tensor(self._materialized_data, self.requires_grad)
             target = LazyTensor(lambda: None, concrete_data=copied)
         else:
-            target = LazyTensor(factory_fn, meta_data=self._meta_data)
+            target = LazyTensor(factory_fn, self, meta_data=self._meta_data)
 
         if isinstance(self, Parameter):
             # hack isinstance check of parameter
@@ -394,14 +446,12 @@ class LazyTensor(torch.Tensor):
         if other is self:
             return
 
-        self._op_buffer.append(other._factory_method)
-
         def replace(x):
             if x is other:
                 return self
             return x
 
-        for func, args, kwargs in other._op_buffer:
+        for func, args, kwargs in [other._factory_method, *other._op_buffer]:
             self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))
 
     def tolist(self) -> list:
@@ -455,7 +505,6 @@ class LazyInitContext:
         default_device: Optional[Union[torch.device, str, int]] = None,
     ):
         assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
-        self.overrides = {}
         self.tensor_cls = tensor_cls
         self.old_default_device = LazyTensor.default_device
         self.default_device = default_device
@@ -478,7 +527,9 @@ class LazyInitContext:
             # factory_like functions (eg. torch.empty_like())
             def wrapper(*args, **kwargs):
                 orig_t = args[0]
-                return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
+                return self.tensor_cls(
+                    orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
+                )
 
             return wrapper, target
 
@@ -513,13 +564,13 @@ class LazyInitContext:
 
             return wrapper, target
 
-        self.overrides = {
+        overrides = {
             target: wrap_factory_method(getattr(torch, target))
             for target in _NORMAL_FACTORY
             if callable(getattr(torch, target, None))
         }
 
-        self.overrides.update(
+        overrides.update(
             {
                 target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like"))
                 for target in _NORMAL_FACTORY
@@ -527,7 +578,7 @@ class LazyInitContext:
             }
         )
 
-        self.overrides.update(
+        overrides.update(
             {
                 target: wrap_legacy_constructor(getattr(torch, target), dtype)
                 for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
@@ -535,7 +586,7 @@ class LazyInitContext:
             }
         )
 
-        self.overrides.update(
+        overrides.update(
            {
                 target: wrap_no_meta_factory(getattr(torch, target))
                 for target in _NO_META_FACTORY
@@ -543,14 +594,12 @@ class LazyInitContext:
             }
         )
 
-        for name, (wrapper, orig) in self.overrides.items():
-            setattr(torch, name, wrapper)
+        ConstructorManager.apply(overrides)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.tensor_cls.default_device = self.old_default_device
         LazyInitContext._replaced = False
-        for name, (wrapper, orig) in self.overrides.items():
-            setattr(torch, name, orig)
+        ConstructorManager.clear()
 
     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
@@ -566,23 +615,6 @@ class LazyInitContext:
 
         return _apply_to_lazy_module(module, apply_fn, verbose)
 
-    @staticmethod
-    def distribute(
-        module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False
-    ) -> nn.Module:
-        """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
-
-        Args:
-            module (nn.Module): Target ``nn.Module``
-            layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
-            verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
-        """
-
-        def apply_fn(name: str, p: LazyTensor):
-            p.distribute(device_mesh, sharding_spec_dict[name])
-
-        return _apply_to_lazy_module(module, apply_fn, verbose)
-
 
 def _apply_to_lazy_module(
     module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False
@@ -622,20 +654,17 @@ def _apply_to_lazy_module(
 
     if verbose:
         non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
-        _print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}")
-        _print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}")
-        _print_rank_0(
-            f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%"
+        logger = get_dist_logger()
+        logger.info(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}", ranks=[0])
+        logger.info(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}", ranks=[0])
+        logger.info(
+            f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%",
+            ranks=[0],
         )
 
     return module
 
 
-def _print_rank_0(*args, **kwargs):
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        print(*args, **kwargs)
-
-
 def _is_int_tuple(args) -> bool:
     if not isinstance(args, tuple):
         return False
diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py
index 978cf06b5..a1b5763d4 100644
--- a/tests/test_lazy/test_models.py
+++ b/tests/test_lazy/test_models.py
@@ -11,14 +11,12 @@ def test_torchvision_models_lazy_init(subset, default_device):
     sub_model_zoo = model_zoo.get_sub_registry(subset)
     for name, entry in sub_model_zoo.items():
         # TODO(ver217): lazy init does not support weight norm, skip these models
-        if (
-            name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base")
-            or name.startswith("transformers_llama")
-            or name.startswith(("transformers_vit", "transformers_blip2"))
+        if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
+            ("transformers_vit", "transformers_blip2")
         ):
             continue
         check_lazy_init(entry, verbose=True, default_device=default_device)
 
 
 if __name__ == "__main__":
-    test_torchvision_models_lazy_init("torchvision")
+    test_torchvision_models_lazy_init("transformers", "cpu")
diff --git a/tests/test_lazy/test_ops.py b/tests/test_lazy/test_ops.py
new file mode 100644
index 000000000..e6b936198
--- /dev/null
+++ b/tests/test_lazy/test_ops.py
@@ -0,0 +1,64 @@
+import copy
+
+import pytest
+import torch
+import torch.nn as nn
+from lazy_init_utils import SUPPORT_LAZY
+from torch.nn import Parameter
+
+from colossalai.lazy import LazyInitContext
+
+
+@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
+def test_lazy_ops():
+    with LazyInitContext():
+        x = torch.rand(2, 3)
+        assert tuple(x.shape) == (2, 3)
+        assert x.device.type == "cpu"
+        assert x.requires_grad is False
+        y = x.cuda()
+        assert tuple(y.shape) == (2, 3)
+        assert y.device.type == "cuda"
+        assert y.requires_grad is False
+        assert x.cpu() is x
+        p = Parameter(torch.empty(2, 3))
+        assert tuple(p.shape) == (2, 3)
+        assert p.device.type == "cpu"
+        assert p.requires_grad is True
+        assert isinstance(p, Parameter)
+    x.materialize()
+    assert tuple(x.shape) == (2, 3)
+    assert x.device.type == "cpu"
+    assert x.requires_grad is False
+    y.materialize()
+    assert tuple(y.shape) == (2, 3)
+    assert y.device.type == "cuda"
+    assert y.requires_grad is False
+    p.materialize()
+    assert tuple(p.shape) == (2, 3)
+    assert p.device.type == "cpu"
+    assert p.requires_grad is True
+    assert isinstance(p, Parameter)
+
+    with LazyInitContext():
+        x = torch.empty(2, 3)
+        x.uniform_()
+    x.materialize()
+    assert tuple(x.shape) == (2, 3)
+
+    with LazyInitContext():
+        model = nn.Linear(3, 4)
+        model = model.cuda()
+    model_copied = copy.deepcopy(model)
+    LazyInitContext.materialize(model)
+    assert model.weight.device.type == "cuda"
+    assert model.bias.device.type == "cuda"
+    LazyInitContext.materialize(model_copied)
+    assert model_copied.weight.device.type == "cuda"
+    assert model_copied.bias.device.type == "cuda"
+    assert torch.equal(model.weight, model_copied.weight)
+    assert torch.equal(model.bias, model_copied.bias)
+
+
+if __name__ == "__main__":
+    test_lazy_ops()
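
Usage sketch (not part of the patch): a minimal example of the lazy initialization flow exercised by this change, distilled from tests/test_lazy/test_ops.py above. The nn.Linear model and the CUDA device move are illustrative only, and a CUDA device is assumed to be available.

    import copy

    import torch.nn as nn

    from colossalai.lazy import LazyInitContext

    # Constructor calls made inside the context produce lazy tensors backed by meta tensors;
    # device moves such as .cuda() are recorded instead of being executed immediately.
    with LazyInitContext():
        model = nn.Linear(3, 4)
        model = model.cuda()

    # A deepcopy of a lazy module stays lazy and can be materialized independently.
    model_copied = copy.deepcopy(model)

    # materialize() allocates real storage and replays the recorded ops in place.
    LazyInitContext.materialize(model)
    assert model.weight.device.type == "cuda"
    assert model.weight.requires_grad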