[lazy] support init on cuda (#4269)

* [lazy] support init on cuda

* [test] update lazy init test

* [test] fix transformer version
Hongxin Liu authored 2023-07-19 16:43:01 +08:00, committed by GitHub
parent 4b977541a8, commit fc5cef2c79
4 changed files with 31 additions and 14 deletions


@@ -1,3 +1,4 @@
from contextlib import contextmanager
from types import MethodType
from typing import Callable, Dict, Optional, Union
@@ -61,12 +62,15 @@ class _MyTensor(Tensor):
"""
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
default_device: Optional[torch.device] = None
def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor':
cls._pre_op_fn()
if concrete_data is not None:
# uniform api as LazyTensor
data = concrete_data
else:
kwargs['device'] = cls.default_device
data = func(*args, **kwargs)
return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
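(Illustration, not part of the diff.) A minimal usage sketch of the new default_device hook: when no concrete_data is passed, __new__ injects the class-level device into the factory call, so the tensor is allocated on CUDA even though the call site never passes device=. Assumes a CUDA device is available and that the module path is colossalai.lazy.lazy_init.

import torch
from colossalai.lazy.lazy_init import _MyTensor   # assumed module path

_MyTensor.default_device = torch.device('cuda')   # normally set by LazyInitContext
t = _MyTensor(torch.empty, 2, 3)                  # __new__ fills in kwargs['device']
assert t.device.type == 'cuda'                    # allocated directly on the GPU
_MyTensor.default_device = None                   # reset by hand here; the context restores it on __exit__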
@@ -142,6 +146,8 @@ class LazyTensor(torch.Tensor):
_meta_data: Optional[MetaTensor] = None # shape, dtype, device
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
default_device: Optional[torch.device] = None
@staticmethod
def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
if concrete_data is not None:
@@ -159,6 +165,8 @@ class LazyTensor(torch.Tensor):
return r
def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
if func.__name__ in _NORMAL_FACTORY:
kwargs = {**kwargs, 'device': LazyTensor.default_device}
self._factory_method = (func, args, kwargs) # (func, args, kwargs)
self._op_buffer = [] # (func, args, kwargs, replace)
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data
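(Illustration, not part of the diff.) The check against _NORMAL_FACTORY means only plain factory functions (torch.empty, torch.zeros, ...; the exact list lives in the library) get the default device injected; ops that derive a tensor from an existing one already carry a device. A rough, self-contained sketch of the name-gated recording done above, with a hypothetical factory list:

import torch

_NORMAL_FACTORY_SKETCH = {'empty', 'zeros', 'ones', 'rand', 'randn'}   # assumption

def record_factory(func, *args, default_device=None, **kwargs):
    # defer the call: nothing is allocated yet, only (func, args, kwargs) is stored
    if func.__name__ in _NORMAL_FACTORY_SKETCH:
        kwargs = {**kwargs, 'device': default_device}
    return (func, args, kwargs)

fm = record_factory(torch.empty, 2, 3, default_device='cuda')
# fm[2] == {'device': 'cuda'}; the real allocation happens only at materialization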
@@ -206,16 +214,11 @@ class LazyTensor(torch.Tensor):
if self._materialized_data is None:
# apply factory method
func, args, kwargs = self._factory_method
# apply cached sequence
self._pre_op_fn()
-try:
-    init_val = func(*tree_map(self._replace_with_materialized, args),
-                    **tree_map(self._replace_with_materialized, kwargs))
-except TypeError as e:
-    print(f'init fn: {func.__name__}')
-    raise e
+init_val = func(*tree_map(self._replace_with_materialized, args),
+                **tree_map(self._replace_with_materialized, kwargs))
self._materialized_data = self._rerun_ops(init_val)
return self._materialized_data
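(Illustration, not part of the diff.) The body above is the heart of materialization: replay the recorded factory call, which now carries the default device, then re-apply the cached ops via _rerun_ops. A stripped-down sketch of that defer-then-replay pattern, using a hypothetical class:

import torch

class DeferredTensor:                           # hypothetical, for illustration only
    default_device = None

    def __init__(self, func, *args, **kwargs):
        kwargs = {**kwargs, 'device': type(self).default_device}
        self._factory_method = (func, args, kwargs)
        self._materialized = None

    def materialize(self) -> torch.Tensor:
        if self._materialized is None:
            func, args, kwargs = self._factory_method
            self._materialized = func(*args, **kwargs)   # first real allocation
        return self._materialized

DeferredTensor.default_device = 'cuda'          # assumes a CUDA device is present
w = DeferredTensor(torch.empty, 1024, 1024)     # no GPU memory used yet
t = w.materialize()                             # created directly on cuda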
@@ -305,6 +308,7 @@ class LazyTensor(torch.Tensor):
else:
# out of place op, create new lazy tensor
fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
fn.__name__ = func.__name__
lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
return lazy_y
elif type(y) is Tensor:
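(Illustration, not part of the diff.) The added fn.__name__ = func.__name__ keeps the wrapped op's real name on the lambda, presumably so name-based checks such as the _NORMAL_FACTORY lookup in __init__ see e.g. 'empty' rather than '<lambda>':

import torch

fn = lambda *a, **kw: torch.empty(*a, **kw)
print(fn.__name__)                  # '<lambda>'
fn.__name__ = torch.empty.__name__
print(fn.__name__)                  # 'empty'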
@@ -435,14 +439,21 @@ class LazyInitContext:
"""
_replaced: bool = False
-def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor):
+def __init__(self,
+             tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor,
+             default_device: Optional[Union[torch.device, str, int]] = None):
assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
self.overrides = {}
self.tensor_cls = tensor_cls
self.old_default_device = LazyTensor.default_device
self.default_device = default_device
def __enter__(self):
if LazyInitContext._replaced:
raise RuntimeError(f'LazyInitContext is not reentrant')
LazyInitContext._replaced = True
self.old_default_device = self.tensor_cls.default_device
self.tensor_cls.default_device = self.default_device
def wrap_factory_method(target):
# factory functions (eg. torch.empty())
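(Illustration, not part of the diff.) __enter__ now publishes default_device onto the tensor class before installing the factory wrappers. A simplified sketch of what such a wrapper does, assuming it simply reroutes torch.empty and friends to the chosen tensor class (the real wrap_factory_method is a closure inside __enter__ and differs in detail):

import torch
from colossalai.lazy.lazy_init import LazyTensor   # assumed module path

def wrap_factory_method(target, tensor_cls):
    # target is e.g. torch.empty; calls now return tensor_cls and defer allocation
    def wrapper(*args, **kwargs):
        return tensor_cls(target, *args, **kwargs)
    wrapper.__name__ = target.__name__
    return wrapper, target

wrapper, orig = wrap_factory_method(torch.empty, LazyTensor)
setattr(torch, 'empty', wrapper)    # installed on __enter__
# ... model construction happens here ...
setattr(torch, 'empty', orig)       # restored on __exit__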
@@ -518,6 +529,7 @@ class LazyInitContext:
setattr(torch, name, wrapper)
def __exit__(self, exc_type, exc_val, exc_tb):
self.tensor_cls.default_device = self.old_default_device
LazyInitContext._replaced = False
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, orig)
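(Illustration, not part of the diff.) Putting it together, the new default_device argument lets lazily initialized modules materialize straight onto the GPU. A usage sketch, assuming the context is importable as colossalai.lazy.LazyInitContext, that LazyInitContext.materialize(module) is available, and that a CUDA device is present:

import torch
import torch.nn as nn
from colossalai.lazy import LazyInitContext    # assumed import path

with LazyInitContext(default_device=torch.device('cuda')):
    model = nn.Linear(4096, 4096)              # parameters are lazy, no real allocation yet

model = LazyInitContext.materialize(model)     # weights are created directly on cuda
assert model.weight.device.type == 'cuda'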