mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[gemini] gemini supports lazy init (#3379)
* [gemini] fix nvme optimizer init
* [gemini] gemini supports lazy init
* [gemini] add init example
* [gemini] add fool model
* [zero] update gemini ddp
* [zero] update init example
* add chunk method
* add chunk method
* [lazyinit] fix lazy tensor tolist
* [gemini] fix buffer materialization
* [misc] remove useless file
* [booster] update gemini plugin
* [test] update gemini plugin test
* [test] fix gemini plugin test
* [gemini] fix import
* [gemini] fix import
* [lazyinit] use new metatensor
* [lazyinit] use new metatensor
* [lazyinit] fix __set__ method
This commit is contained in:
@@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
from colossalai._analyzer._subclasses import MetaTensor
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
|
||||
@@ -37,7 +37,7 @@ _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
|
||||
# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
|
||||
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
|
||||
# These ops cannot be unwrapped using .data
|
||||
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
|
||||
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__']
|
||||
|
||||
_LEGACY_TENSOR_CONSTRUCTOR = {
|
||||
'FloatTensor': torch.float,
|
||||
@@ -75,6 +75,12 @@ class _MyTensor(Tensor):
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
|
||||
def _data_tolist(tensor: torch.Tensor) -> list:
|
||||
"""tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.
|
||||
"""
|
||||
return tensor.data.tolist()
|
||||
|
||||
|
||||
def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert a lazy tensor's class to target's class, with target's data.
|
||||
|
||||
@@ -94,7 +100,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
|
||||
tensor.requires_grad = target.requires_grad
|
||||
# subclass of torch.Tensor does not have tolist() method
|
||||
# overwrite this method after materialization or distribution
|
||||
tensor.tolist = MethodType(torch.Tensor.tolist, target)
|
||||
tensor.tolist = MethodType(_data_tolist, tensor)
|
||||
return tensor
|
||||
|
||||
|
||||
@@ -144,7 +150,7 @@ class LazyTensor(torch.Tensor):
|
||||
if meta_data is None:
|
||||
device = kwargs.get('device', 'cpu')
|
||||
elem = func(*args, **{**kwargs, 'device': 'meta'})
|
||||
meta_data = MetaTensor(elem, fake_device=device)
|
||||
meta_data = MetaTensor(elem, device=device)
|
||||
elem = meta_data._tensor
|
||||
# The __class__ of a meta tensor cannot be changed to torch.Tensor, so we use an empty real tensor here
|
||||
r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
|
||||
@@ -255,7 +261,7 @@ class LazyTensor(torch.Tensor):
|
||||
tree_map(cls._replace_with_materialized, args)
|
||||
tree_map(cls._replace_with_materialized, kwargs)
|
||||
is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
|
||||
or func.__name__ == "__setitem__")
|
||||
or func.__name__ in ('__setitem__', '__set__'))
|
||||
|
||||
is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
|
||||
|
||||
|
Reference in New Issue
Block a user