mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-05-05 12:24:38 +00:00
[lazyinit] combine lazy tensor with dtensor (#3204)
* [lazyinit] lazy tensor add distribute * [lazyinit] refactor distribute * [lazyinit] add test dist lazy init * [lazyinit] add verbose info for dist lazy init * [lazyinit] fix rnn flatten weight op * [lazyinit] polish test * [lazyinit] polish test * [lazyinit] fix lazy tensor data setter * [lazyinit] polish test * [lazyinit] fix clean * [lazyinit] make materialize inplace * [lazyinit] refactor materialize * [lazyinit] refactor test distribute * [lazyinit] fix requires_grad * [lazyinit] fix tolist after materialization * [lazyinit] refactor distribute module * [lazyinit] polish docstr * [lazyinit] polish lazy init context * [lazyinit] temporarily skip test * [lazyinit] polish test * [lazyinit] add docstr
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
from typing import Callable, List, Optional, Union
|
||||
from types import MethodType
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
|
||||
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
|
||||
_NORMAL_FACTORY = [
|
||||
@@ -30,6 +34,11 @@ _NO_META_FACTORY = [
|
||||
|
||||
_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
|
||||
|
||||
# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
|
||||
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
|
||||
# These ops cannot be unwrapped using .data
|
||||
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
|
||||
|
||||
_LEGACY_TENSOR_CONSTRUCTOR = {
|
||||
'FloatTensor': torch.float,
|
||||
'DoubleTensor': torch.double,
|
||||
@@ -43,6 +52,8 @@ _LEGACY_TENSOR_CONSTRUCTOR = {
|
||||
'BoolTensor': torch.bool,
|
||||
}
|
||||
|
||||
_EMPTY_DATA = torch.empty(0)
|
||||
|
||||
|
||||
class _MyTensor(Tensor):
|
||||
"""This class is only for correctness verification.
|
||||
@@ -64,6 +75,29 @@ class _MyTensor(Tensor):
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
|
||||
def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert a lazy tensor's class to target's class, with target's data.
|
||||
|
||||
The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
|
||||
If we create a new tensor and update the module by ``setattr(module, name, param)``, the shared parameters will not be updated. And we have to track all shared parameters and update them manually.
|
||||
|
||||
Args:
|
||||
tensor (LazyTensor): the LazyTensor to be converted
|
||||
target (torch.Tensor): target tensor
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the converted tensor
|
||||
"""
|
||||
cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
|
||||
tensor.__class__ = cls_to_become
|
||||
tensor.data = target
|
||||
tensor.requires_grad = target.requires_grad
|
||||
# subclass of torch.Tensor does not have tolist() method
|
||||
# overwrite this method after materialization or distribution
|
||||
tensor.tolist = MethodType(torch.Tensor.tolist, target)
|
||||
return tensor
|
||||
|
||||
|
||||
class LazyTensor(torch.Tensor):
|
||||
"""A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).
|
||||
|
||||
@@ -112,14 +146,8 @@ class LazyTensor(torch.Tensor):
|
||||
elem = func(*args, **{**kwargs, 'device': 'meta'})
|
||||
meta_data = MetaTensor(elem, fake_device=device)
|
||||
elem = meta_data._tensor
|
||||
r = torch.Tensor._make_wrapper_subclass(cls,
|
||||
elem.size(),
|
||||
strides=elem.stride(),
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=elem.device,
|
||||
requires_grad=elem.requires_grad)
|
||||
# As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
|
||||
r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
|
||||
r._meta_data = meta_data
|
||||
return r
|
||||
|
||||
@@ -129,15 +157,28 @@ class LazyTensor(torch.Tensor):
|
||||
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data
|
||||
|
||||
def materialize(self) -> torch.Tensor:
|
||||
"""Materialize the ``LazyTensor`` to ``torch.Tensor``.
|
||||
"""Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The materialized tensor.
|
||||
torch.Tensor: The materialized tensor (self).
|
||||
"""
|
||||
target = self._materialize_data()
|
||||
if isinstance(self, nn.Parameter):
|
||||
target = nn.Parameter(target, requires_grad=self.requires_grad)
|
||||
return target
|
||||
self.clean()
|
||||
return _convert_cls(self, target)
|
||||
|
||||
def distribute(self, layout: Layout) -> torch.Tensor:
|
||||
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
|
||||
|
||||
Args:
|
||||
layout (Layout): Distribution layout.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The distributed tensor (self).
|
||||
"""
|
||||
target = self._materialize_data()
|
||||
self.clean()
|
||||
local_tensor = DTensor(target, layout).local_tensor
|
||||
return _convert_cls(self, local_tensor)
|
||||
|
||||
def clean(self) -> None:
|
||||
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
|
||||
@@ -216,6 +257,8 @@ class LazyTensor(torch.Tensor):
|
||||
is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
|
||||
or func.__name__ == "__setitem__")
|
||||
|
||||
is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
|
||||
|
||||
if isinstance(func, torch._C.ScriptMethod):
|
||||
# FIXME(ver217): torch script functions are not verified
|
||||
|
||||
@@ -239,10 +282,10 @@ class LazyTensor(torch.Tensor):
|
||||
if isinstance(x, LazyTensor):
|
||||
if x._materialized_data is not None:
|
||||
# for early materialized tensor, use its materialized data directly
|
||||
return x._materialized_data.data
|
||||
return x._materialized_data if is_change_meta_op else x._materialized_data.data
|
||||
t = x if is_inplace else x.clone()
|
||||
t._op_buffer.append((func, args, kwargs))
|
||||
meta = x._meta_data.data
|
||||
meta = x._meta_data if is_change_meta_op else x._meta_data.data
|
||||
meta_to_lazy[meta] = t
|
||||
return meta
|
||||
return x
|
||||
@@ -290,13 +333,36 @@ class LazyTensor(torch.Tensor):
|
||||
|
||||
@data.setter
|
||||
def data(self, other: 'LazyTensor'):
|
||||
"""This is sightly different from oringinal `data` setter.
|
||||
|
||||
E.g.:
|
||||
>>> a = torch.randn(3, 3) # a is a Tensor
|
||||
>>> b = torch.rand(2, 2)
|
||||
>>> a.data = b
|
||||
>>> b.add_(1) # this will affect a
|
||||
>>> x = torch.randn(3, 3) # x is a LazyTensor
|
||||
>>> y = torch.rand(2, 2) # y is a LazyTensor
|
||||
>>> x.data = y
|
||||
>>> y.add_(1) # this will not affect x
|
||||
|
||||
"""
|
||||
if other is self:
|
||||
return
|
||||
# TODO(ver217): to avoid infinity recursion, do early materialization
|
||||
self._materialized_data = other._materialize_data()
|
||||
|
||||
self._op_buffer.append(other._factory_method)
|
||||
|
||||
def replace(x):
|
||||
if x is other:
|
||||
return self
|
||||
return x
|
||||
|
||||
for func, args, kwargs in other._op_buffer:
|
||||
self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))
|
||||
|
||||
def tolist(self) -> list:
|
||||
t = self.materialize()
|
||||
# Though self.__class__ is modified to torch.Tensor, in C++ side, it is still a subclass of torch.Tensor
|
||||
# And subclass of torch.Tensor does not have tolist() method
|
||||
t = self._materialize_data()
|
||||
return t.tolist()
|
||||
|
||||
def __hash__(self):
|
||||
@@ -421,71 +487,84 @@ class LazyInitContext:
|
||||
setattr(torch, name, orig)
|
||||
|
||||
@staticmethod
|
||||
def materialize(module: torch.nn.Module, verbose: bool = False):
|
||||
"""Initialize all ``nn.Parameter`` from ``LazyTensor``.
|
||||
def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
|
||||
"""Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): Target ``nn.Module``
|
||||
module (nn.Module): Target ``nn.Module``
|
||||
verbose (bool): Whether to print lazy initialization rate. Defaults to False.
|
||||
"""
|
||||
|
||||
def apply_fn(name: str, p: LazyTensor):
|
||||
p.materialize()
|
||||
|
||||
return _apply_to_lazy_module(module, apply_fn, verbose)
|
||||
|
||||
@staticmethod
|
||||
def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
|
||||
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
|
||||
|
||||
Args:
|
||||
module (nn.Module): Target ``nn.Module``
|
||||
layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
|
||||
verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
|
||||
"""
|
||||
|
||||
def apply_fn(name: str, p: LazyTensor):
|
||||
p.distribute(layout_dict[name])
|
||||
|
||||
return _apply_to_lazy_module(module, apply_fn, verbose)
|
||||
|
||||
|
||||
def _apply_to_lazy_module(module: nn.Module,
|
||||
apply_fn: Callable[[str, torch.Tensor], None],
|
||||
verbose: bool = False) -> nn.Module:
|
||||
if verbose:
|
||||
# verbose info
|
||||
param_cnt = 0
|
||||
param_lazy_cnt = 0
|
||||
buf_cnt = 0
|
||||
buf_lazy_cnt = 0
|
||||
total_numel = 0
|
||||
non_lazy_numel = 0
|
||||
|
||||
for name, p in module.named_parameters():
|
||||
if verbose:
|
||||
param_cnt = 0
|
||||
param_lazy_cnt = 0
|
||||
buf_cnt = 0
|
||||
buf_lazy_cnt = 0
|
||||
non_lazy_numel = 0
|
||||
|
||||
# do post cleaning to handle shared parameter
|
||||
visited_lazy_tensors: List[LazyTensor] = []
|
||||
# handle shared module
|
||||
visited_modules = set()
|
||||
|
||||
@torch.no_grad()
|
||||
def init_recursively(module: nn.Module):
|
||||
nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel
|
||||
# recursively initialize the module
|
||||
for mod in module.children():
|
||||
if id(mod) not in visited_modules:
|
||||
visited_modules.add(id(mod))
|
||||
init_recursively(mod)
|
||||
|
||||
# initialize tensors directly attached to the current module
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
if verbose:
|
||||
param_cnt += 1
|
||||
if getattr(param, '_materialized_data', False) is None:
|
||||
# if no _materialized_data attr, the tensor is not lazy
|
||||
param_lazy_cnt += 1
|
||||
else:
|
||||
non_lazy_numel += param.numel()
|
||||
if hasattr(param, 'materialize'):
|
||||
# TODO(ver217): apex layers cannot be captured
|
||||
visited_lazy_tensors.append(param)
|
||||
setattr(module, name, param.materialize())
|
||||
|
||||
for name, buf in module.named_buffers(recurse=False):
|
||||
if verbose:
|
||||
buf_cnt += 1
|
||||
if getattr(buf, "_materialized_data", False) is None:
|
||||
# if no _materialized_data attr, the tensor is not lazy
|
||||
buf_lazy_cnt += 1
|
||||
else:
|
||||
non_lazy_numel += buf.numel()
|
||||
if hasattr(buf, 'materialize'):
|
||||
# TODO(ver217): apex layers cannot be captured
|
||||
visited_lazy_tensors.append(buf)
|
||||
setattr(module, name, buf.materialize())
|
||||
|
||||
init_recursively(module)
|
||||
|
||||
for t in visited_lazy_tensors:
|
||||
t.clean()
|
||||
param_cnt += 1
|
||||
total_numel += p.numel()
|
||||
if getattr(p, '_materialized_data', False) is None:
|
||||
# if no _materialized_data attr, the tensor is not lazy
|
||||
param_lazy_cnt += 1
|
||||
else:
|
||||
non_lazy_numel += p.numel()
|
||||
if isinstance(p, LazyTensor):
|
||||
apply_fn(name, p)
|
||||
|
||||
for name, buf in module.named_buffers():
|
||||
if verbose:
|
||||
print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
|
||||
print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
|
||||
print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)')
|
||||
return module
|
||||
buf_cnt += 1
|
||||
total_numel += buf.numel()
|
||||
if getattr(buf, "_materialized_data", False) is None:
|
||||
# if no _materialized_data attr, the tensor is not lazy
|
||||
buf_lazy_cnt += 1
|
||||
else:
|
||||
non_lazy_numel += buf.numel()
|
||||
if isinstance(buf, LazyTensor):
|
||||
apply_fn(name, buf)
|
||||
|
||||
if verbose:
|
||||
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
|
||||
_print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
|
||||
_print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
|
||||
_print_rank_0(
|
||||
f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%')
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _print_rank_0(*args, **kwargs):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
def _is_int_tuple(args) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user