From fc5cef2c79265e36b585ef22c5e1d7f18be52a4e Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 19 Jul 2023 16:43:01 +0800 Subject: [PATCH] [lazy] support init on cuda (#4269) * [lazy] support init on cuda * [test] update lazy init test * [test] fix transformer version --- colossalai/lazy/lazy_init.py | 28 ++++++++++++++++++++-------- requirements/requirements-test.txt | 2 +- tests/test_lazy/lazy_init_utils.py | 10 +++++++--- tests/test_lazy/test_models.py | 5 +++-- 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 8b9114073..1f5345015 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from types import MethodType from typing import Callable, Dict, Optional, Union @@ -61,12 +62,15 @@ class _MyTensor(Tensor): """ _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None + default_device: Optional[torch.device] = None + def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor': cls._pre_op_fn() if concrete_data is not None: # uniform api as LazyTensor data = concrete_data else: + kwargs['device'] = cls.default_device data = func(*args, **kwargs) return Tensor._make_subclass(cls, data, require_grad=data.requires_grad) @@ -142,6 +146,8 @@ class LazyTensor(torch.Tensor): _meta_data: Optional[MetaTensor] = None # shape, dtype, device _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None + default_device: Optional[torch.device] = None + @staticmethod def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): if concrete_data is not None: @@ -159,6 +165,8 @@ class LazyTensor(torch.Tensor): return r def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): + if func.__name__ in _NORMAL_FACTORY: + kwargs = {**kwargs, 'device': LazyTensor.default_device} self._factory_method = (func, args, kwargs) # (func, args, kwargs) self._op_buffer = [] # (func, args, kwargs, replace) self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data @@ -206,16 +214,11 @@ class LazyTensor(torch.Tensor): if self._materialized_data is None: # apply factory method func, args, kwargs = self._factory_method - # apply cached sequence self._pre_op_fn() - try: - init_val = func(*tree_map(self._replace_with_materialized, args), - **tree_map(self._replace_with_materialized, kwargs)) - except TypeError as e: - print(f'init fn: {func.__name__}') - raise e + init_val = func(*tree_map(self._replace_with_materialized, args), + **tree_map(self._replace_with_materialized, kwargs)) self._materialized_data = self._rerun_ops(init_val) return self._materialized_data @@ -305,6 +308,7 @@ class LazyTensor(torch.Tensor): else: # out of place op, create new lazy tensor fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i] + fn.__name__ = func.__name__ lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs) return lazy_y elif type(y) is Tensor: @@ -435,14 +439,21 @@ class LazyInitContext: """ _replaced: bool = False - def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor): + def __init__(self, + tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor, + default_device: Optional[Union[torch.device, str, int]] = None): + assert tensor_cls is LazyTensor or tensor_cls is _MyTensor self.overrides = {} self.tensor_cls = tensor_cls + self.old_default_device = LazyTensor.default_device + self.default_device = default_device def __enter__(self): if LazyInitContext._replaced: raise RuntimeError(f'LazyInitContext is not reentrant') LazyInitContext._replaced = True + self.old_default_device = self.tensor_cls.default_device + self.tensor_cls.default_device = self.default_device def wrap_factory_method(target): # factory functions (eg. torch.empty()) @@ -518,6 +529,7 @@ class LazyInitContext: setattr(torch, name, wrapper) def __exit__(self, exc_type, exc_val, exc_tb): + self.tensor_cls.default_device = self.old_default_device LazyInitContext._replaced = False for name, (wrapper, orig) in self.overrides.items(): setattr(torch, name, orig) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 50121a928..9f6580c72 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -4,7 +4,7 @@ pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon torchvision -transformers +transformers==4.30.2 timm titans torchaudio diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 73c3c5422..9d9e9a3a5 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -61,14 +61,18 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}' -def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None: +def check_lazy_init(entry: TestingEntry, + seed: int = 42, + verbose: bool = False, + check_forward: bool = False, + default_device: str = 'cpu') -> None: model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry _MyTensor._pre_op_fn = lambda *args: set_seed(seed) LazyTensor._pre_op_fn = lambda *args: set_seed(seed) - ctx = LazyInitContext(tensor_cls=_MyTensor) + ctx = LazyInitContext(tensor_cls=_MyTensor, default_device=default_device) with ctx: model = model_fn() - ctx = LazyInitContext() + ctx = LazyInitContext(default_device=default_device) with ctx: deferred_model = model_fn() copied_deferred_model = deepcopy(deferred_model) diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index 4b7aeed73..e37184125 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -6,13 +6,14 @@ from tests.kit.model_zoo import model_zoo @pytest.mark.skipif(not SUPPORT_LAZY, reason='requires torch >= 1.12.0') @pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) -def test_torchvision_models_lazy_init(subset): +@pytest.mark.parametrize('default_device', ['cpu', 'cuda']) +def test_torchvision_models_lazy_init(subset, default_device): sub_model_zoo = model_zoo.get_sub_registry(subset) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'): continue - check_lazy_init(entry, verbose=True) + check_lazy_init(entry, verbose=True, default_device=default_device) if __name__ == '__main__':