[gemini] gemini supports lazy init (#3379)

* [gemini] fix nvme optimizer init

* [gemini] gemini supports lazy init

* [gemini] add init example

* [gemini] add fool model

* [zero] update gemini ddp

* [zero] update init example

* add chunk method

* add chunk method

* [lazyinit] fix lazy tensor tolist

* [gemini] fix buffer materialization

* [misc] remove useless file

* [booster] update gemini plugin

* [test] update gemini plugin test

* [test] fix gemini plugin test

* [gemini] fix import

* [gemini] fix import

* [lazyinit] use new metatensor

* [lazyinit] use new metatensor

* [lazyinit] fix __set__ method
Hongxin Liu
2023-04-12 16:03:25 +08:00
committed by GitHub
parent 366a035552
commit 152239bbfa
7 changed files with 80 additions and 72 deletions
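
Note: as a usage sketch of the headline feature, model construction under a lazy-init context defers parameter allocation until Gemini wraps the model. The exact import paths and Booster API at this point in history are assumptions (lazy init moved between modules across releases), so treat this as illustrative, not as the commit's own example:

import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext  # import path is an assumption

colossalai.launch_from_torch(config={})
with LazyInitContext():
    # parameters are LazyTensors here: shape/dtype known, no storage allocated
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
booster = Booster(plugin=GeminiPlugin())
model, *_ = booster.boost(model)  # Gemini materializes and shards the lazy params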


@@ -7,7 +7,7 @@ import torch.nn as nn
 from torch import Tensor
 from torch.utils._pytree import tree_map
 
-from colossalai.fx.profiler.tensor import MetaTensor
+from colossalai._analyzer._subclasses import MetaTensor
 from colossalai.tensor.d_tensor.d_tensor import DTensor
 from colossalai.tensor.d_tensor.layout import Layout
@@ -37,7 +37,7 @@ _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
 # If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
 # without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
 # These ops cannot be unwrapped using .data
-_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
+_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__']
 
 _LEGACY_TENSOR_CONSTRUCTOR = {
     'FloatTensor': torch.float,
@@ -75,6 +75,12 @@ class _MyTensor(Tensor):
         return super().__torch_function__(func, types, args, kwargs)
 
+
+def _data_tolist(tensor: torch.Tensor) -> list:
+    """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.
+    """
+    return tensor.data.tolist()
+
 
 def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     """Convert a lazy tensor's class to target's class, with target's data.
@@ -94,7 +100,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     tensor.requires_grad = target.requires_grad
     # subclass of torch.Tensor does not have tolist() method
     # overwrite this method after materialization or distribution
-    tensor.tolist = MethodType(torch.Tensor.tolist, target)
+    tensor.tolist = MethodType(_data_tolist, tensor)
     return tensor
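
Note: the rebinding above matters because MethodType(f, obj) fixes which object becomes self: the old code bound torch.Tensor.tolist to target, so the patched method kept reading target, while the fix binds _data_tolist to the materialized tensor itself. A standalone illustration with hypothetical names:

from types import MethodType

class Box:
    def __init__(self, v):
        self.v = v

def show(self):
    return self.v

a, b = Box(1), Box(2)
a.show_v = MethodType(show, b)  # bound to b: self is b
print(a.show_v())               # 2 (reads b, like the old binding to target)
a.show_v = MethodType(show, a)  # bound to a: self is a
print(a.show_v())               # 1 (reads a, like the fixed binding)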
@@ -144,7 +150,7 @@ class LazyTensor(torch.Tensor):
         if meta_data is None:
             device = kwargs.get('device', 'cpu')
             elem = func(*args, **{**kwargs, 'device': 'meta'})
-            meta_data = MetaTensor(elem, fake_device=device)
+            meta_data = MetaTensor(elem, device=device)
         elem = meta_data._tensor
         # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
         r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
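
Note: this construction path is the core of lazy init: run the factory op on the 'meta' device to record shape/dtype without allocating storage, and keep the intended device for later materialization. A simplified, library-free sketch of the idea (DeferredTensor is a hypothetical name, not the LazyTensor API):

import torch

class DeferredTensor:
    def __init__(self, func, *args, **kwargs):
        self._factory = (func, args, kwargs)
        self._device = kwargs.get('device', 'cpu')
        # meta tensors carry sizes/strides/dtype but allocate no storage
        self.meta = func(*args, **{**kwargs, 'device': 'meta'})

    def materialize(self) -> torch.Tensor:
        func, args, kwargs = self._factory
        return func(*args, **{**kwargs, 'device': self._device})

t = DeferredTensor(torch.empty, 1024, 1024, device='cpu')
print(t.meta.shape)     # torch.Size([1024, 1024]); nothing allocated yet
real = t.materialize()  # storage is allocated only here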
@@ -255,7 +261,7 @@ class LazyTensor(torch.Tensor):
         tree_map(cls._replace_with_materialized, args)
         tree_map(cls._replace_with_materialized, kwargs)
 
         is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
-                            or func.__name__ == "__setitem__")
+                            or func.__name__ in ('__setitem__', '__set__'))
 
         is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
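
Note: the predicate reads as follows: a trailing single underscore marks an in-place op (add_, mul_), double underscores mark dunders and are excluded, and __setitem__/__set__ mutate state despite the dunder form. A small self-check of the rule exactly as written in the diff:

def _is_inplace(name: str) -> bool:
    return (name.endswith('_') and not name.endswith('__')
            or name in ('__setitem__', '__set__'))

assert _is_inplace('add_')         # classic in-place op
assert not _is_inplace('add')      # out-of-place
assert not _is_inplace('__add__')  # dunder, excluded by the '__' check
assert _is_inplace('__setitem__')  # mutates despite the dunder form
assert _is_inplace('__set__')      # newly handled by this commit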