Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 10:06:44 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,6 +1,6 @@
 from .lazy_init import LazyInitContext, LazyTensor
 
 __all__ = [
-    'LazyInitContext',
-    'LazyTensor',
+    "LazyInitContext",
+    "LazyTensor",
 ]
@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from types import MethodType
 from typing import Callable, Dict, Optional, Union
 
@@ -35,43 +34,43 @@ _NO_META_FACTORY = [
     "eye",
 ]
 
-_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
+_EARLY_MATERIALIZED_OPS = ["__getitem__", "split"]
 
 # If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
 # without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
 # These ops cannot be unwrapped using .data
-_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__', 'numel', 'size', 'dim']
+_CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"]
 
 _LEGACY_TENSOR_CONSTRUCTOR = {
-    'FloatTensor': torch.float,
-    'DoubleTensor': torch.double,
-    'HalfTensor': torch.half,
-    'BFloat16Tensor': torch.bfloat16,
-    'ByteTensor': torch.uint8,
-    'CharTensor': torch.int8,
-    'ShortTensor': torch.short,
-    'IntTensor': torch.int,
-    'LongTensor': torch.long,
-    'BoolTensor': torch.bool,
+    "FloatTensor": torch.float,
+    "DoubleTensor": torch.double,
+    "HalfTensor": torch.half,
+    "BFloat16Tensor": torch.bfloat16,
+    "ByteTensor": torch.uint8,
+    "CharTensor": torch.int8,
+    "ShortTensor": torch.short,
+    "IntTensor": torch.int,
+    "LongTensor": torch.long,
+    "BoolTensor": torch.bool,
 }
 
 _EMPTY_DATA = torch.empty(0)
 
 
 class _MyTensor(Tensor):
-    """This class is only for correctness verification.
-    """
-    _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
+    """This class is only for correctness verification."""
+
+    _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
 
     default_device: Optional[torch.device] = None
 
-    def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor':
+    def __new__(cls, func, *args, concrete_data=None, **kwargs) -> "_MyTensor":
         cls._pre_op_fn()
         if concrete_data is not None:
             # uniform api as LazyTensor
             data = concrete_data
         else:
-            kwargs['device'] = cls.default_device
+            kwargs["device"] = cls.default_device
             data = func(*args, **kwargs)
         return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
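For reference, the `_MyTensor` path above simply runs the factory eagerly and rewraps the result via `Tensor._make_subclass`. A minimal standalone sketch of that same pattern (the class name here is illustrative, not from this commit):

    import torch
    from torch import Tensor

    class EagerVerifyTensor(Tensor):  # hypothetical stand-in for _MyTensor
        def __new__(cls, func, *args, **kwargs):
            # run the factory immediately, then rewrap the result as this subclass
            data = func(*args, **kwargs)
            return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)

    t = EagerVerifyTensor(torch.empty, 2, 3)
    print(type(t).__name__, tuple(t.shape))  # EagerVerifyTensor (2, 3)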
@@ -82,12 +81,11 @@ class _MyTensor(Tensor):
 
 
 def _data_tolist(tensor: torch.Tensor) -> list:
-    """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.
-    """
+    """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor."""
     return tensor.data.tolist()
 
 
-def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
+def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
     """Convert a lazy tensor's class to target's class, with target's data.
 
     The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
@@ -104,7 +102,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     tensor.__class__ = cls_to_become
     if cls_to_become is Parameter:
         # to fit UninitializedParameter
-        delattr(tensor, '_is_param')
+        delattr(tensor, "_is_param")
     tensor.data = target
     tensor.requires_grad = target.requires_grad
     # subclass of torch.Tensor does not have tolist() method
@@ -147,8 +145,8 @@ class LazyTensor(torch.Tensor):
     """
 
     _repr = True
-    _meta_data: Optional[MetaTensor] = None    # shape, dtype, device
-    _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
+    _meta_data: Optional[MetaTensor] = None  # shape, dtype, device
+    _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
 
     default_device: Optional[torch.device] = None
 
@@ -159,8 +157,8 @@ class LazyTensor(torch.Tensor):
             elem = concrete_data
         else:
             if meta_data is None:
-                device = kwargs.get('device', 'cpu')
-                elem = func(*args, **{**kwargs, 'device': 'meta'})
+                device = kwargs.get("device", "cpu")
+                elem = func(*args, **{**kwargs, "device": "meta"})
                 meta_data = MetaTensor(elem, device=device)
             elem = meta_data._tensor
         # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
@@ -170,10 +168,10 @@ class LazyTensor(torch.Tensor):
 
     def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
         if func.__name__ in _NORMAL_FACTORY:
-            kwargs = {**kwargs, 'device': LazyTensor.default_device}
-        self._factory_method = (func, args, kwargs)    # (func, args, kwargs)
-        self._op_buffer = []    # (func, args, kwargs, replace)
-        self._materialized_data: Optional[torch.Tensor] = concrete_data    # materialized data
+            kwargs = {**kwargs, "device": LazyTensor.default_device}
+        self._factory_method = (func, args, kwargs)  # (func, args, kwargs)
+        self._op_buffer = []  # (func, args, kwargs, replace)
+        self._materialized_data: Optional[torch.Tensor] = concrete_data  # materialized data
 
     def materialize(self) -> torch.Tensor:
         """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
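The `materialize` docstring above is the heart of the API. A usage sketch, assuming the `colossalai.lazy` exports from the first hunk and the `LazyInitContext.materialize` static method that appears further down:

    import torch.nn as nn
    from colossalai.lazy import LazyInitContext

    with LazyInitContext():
        model = nn.Linear(1024, 1024)  # parameters are LazyTensors; nothing is allocated yet

    # swaps each LazyTensor's __class__ to torch.Tensor/Parameter in place,
    # which keeps shared parameters shared
    LazyInitContext.materialize(model)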
@@ -200,12 +198,11 @@ class LazyTensor(torch.Tensor):
         return _convert_cls(self, local_tensor)
 
     def clean(self) -> None:
-        """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
-        """
-        delattr(self, '_factory_method')
-        delattr(self, '_op_buffer')
-        delattr(self, '_materialized_data')
-        delattr(self, '_meta_data')
+        """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
+        delattr(self, "_factory_method")
+        delattr(self, "_op_buffer")
+        delattr(self, "_materialized_data")
+        delattr(self, "_meta_data")
 
     @staticmethod
     def _replace_with_materialized(x):
@@ -221,8 +218,9 @@ class LazyTensor(torch.Tensor):
         # apply cached sequence
         self._pre_op_fn()
 
-        init_val = func(*tree_map(self._replace_with_materialized, args),
-                        **tree_map(self._replace_with_materialized, kwargs))
+        init_val = func(
+            *tree_map(self._replace_with_materialized, args), **tree_map(self._replace_with_materialized, kwargs)
+        )
 
         self._materialized_data = self._rerun_ops(init_val)
         return self._materialized_data
@@ -243,13 +241,13 @@ class LazyTensor(torch.Tensor):
 
         packed = None
 
-        for (func, args, kwargs) in self._op_buffer:
+        for func, args, kwargs in self._op_buffer:
             if func == torch.Tensor.requires_grad_:
-                packed = func, args, kwargs    # requires grad should be set at last
+                packed = func, args, kwargs  # requires grad should be set at last
             else:
                 self._pre_op_fn()
                 o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
-                target = o if isinstance(o, torch.Tensor) else target    # if func returns non-Tensor, discard the value
+                target = o if isinstance(o, torch.Tensor) else target  # if func returns non-Tensor, discard the value
 
         # super-dainiu: set requires_grad after all inplace-ops are done
         if packed is not None:
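Deferring `requires_grad_` to the end of the replay reflects a plain autograd rule, reproducible in isolation:

    import torch

    t = torch.empty(4)
    t.requires_grad_()  # if this replayed before the cached in-place ops...
    try:
        t.uniform_()    # ...each later in-place op on the leaf would fail
    except RuntimeError as e:
        print(e)  # a leaf Variable that requires grad is being used in an in-place operation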
@@ -268,8 +266,11 @@ class LazyTensor(torch.Tensor):
             # These OPs cannot be lazy and related tensors should be early materialized
             tree_map(cls._replace_with_materialized, args)
             tree_map(cls._replace_with_materialized, kwargs)
-        is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
-                            or func.__name__ in ('__setitem__', '__set__'))
+        is_inplace: bool = (
+            func.__name__.endswith("_")
+            and not (func.__name__.endswith("__"))
+            or func.__name__ in ("__setitem__", "__set__")
+        )
 
         is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
 
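The reformatted `is_inplace` predicate encodes PyTorch's naming convention: a trailing underscore marks an in-place op, a double underscore marks a dunder, and the two setters mutate as well. Restated as a standalone helper (the helper name is ours):

    def looks_inplace(name: str) -> bool:
        return name.endswith("_") and not name.endswith("__") or name in ("__setitem__", "__set__")

    assert looks_inplace("add_")         # Tensor.add_ mutates in place
    assert not looks_inplace("__add__")  # dunder, not in-place
    assert looks_inplace("__setitem__")  # t[0] = ... mutates in place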
@@ -285,11 +286,11 @@ class LazyTensor(torch.Tensor):
 
                 target: LazyTensor = args[0].clone()
                 target._op_buffer.append((func, args, kwargs))
-                target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]),
-                                                                          **tree_map(unwrap, kwargs))
+                target._meta_data = getattr(target._meta_data, func.name)(
+                    *tree_map(unwrap, args[1:]), **tree_map(unwrap, kwargs)
+                )
                 return target
         else:
-
             meta_to_lazy = {}
 
             def unwrap(x):
@@ -328,10 +329,9 @@ class LazyTensor(torch.Tensor):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        pass    # skip
+        pass  # skip
 
     def clone(self) -> "LazyTensor":
-
         def factory_fn():
             # if self is materialized, return self
             new_tensor = self.materialize() if type(self) is LazyTensor else self
@@ -346,8 +346,10 @@ class LazyTensor(torch.Tensor):
 
     def __deepcopy__(self, memo):
         if not self.is_leaf:
-            raise RuntimeError("Only Tensors created explicitly by the user "
-                               "(graph leaves) support the deepcopy protocol at the moment")
+            raise RuntimeError(
+                "Only Tensors created explicitly by the user "
+                "(graph leaves) support the deepcopy protocol at the moment"
+            )
         if id(self) in memo:
             return memo[id(self)]
@@ -375,7 +377,7 @@ class LazyTensor(torch.Tensor):
         return self
 
     @data.setter
-    def data(self, other: 'LazyTensor'):
+    def data(self, other: "LazyTensor"):
         """This is sightly different from oringinal `data` setter.
 
         E.g.:
@@ -413,7 +415,7 @@ class LazyTensor(torch.Tensor):
 
     def __rpow__(self, other):
         dtype = torch.result_type(self, other)
-        return torch.tensor(other, dtype=dtype, device=self.device)**self
+        return torch.tensor(other, dtype=dtype, device=self.device) ** self
 
 
 class LazyInitContext:
@@ -444,11 +446,14 @@ class LazyInitContext:
         1. Quantization strategies can be applied before allocating real memory.
         2. Lazy initialization seems slower than normal initialization.
     """
+
     _replaced: bool = False
 
-    def __init__(self,
-                 tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor,
-                 default_device: Optional[Union[torch.device, str, int]] = None):
+    def __init__(
+        self,
+        tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor,
+        default_device: Optional[Union[torch.device, str, int]] = None,
+    ):
         assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
         self.overrides = {}
         self.tensor_cls = tensor_cls
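A constructor sketch: `tensor_cls=_MyTensor` runs everything eagerly (useful for verifying that laziness preserves behaviour), while `default_device` is forwarded into the intercepted factory calls. The `"cuda"` argument below assumes a GPU is available:

    from colossalai.lazy import LazyInitContext

    with LazyInitContext(default_device="cuda"):
        ...  # factory calls are rewritten to target device="cuda"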
@@ -457,7 +462,7 @@ class LazyInitContext:
 
     def __enter__(self):
         if LazyInitContext._replaced:
-            raise RuntimeError(f'LazyInitContext is not reentrant')
+            raise RuntimeError(f"LazyInitContext is not reentrant")
         LazyInitContext._replaced = True
         self.old_default_device = self.tensor_cls.default_device
         self.tensor_cls.default_device = self.default_device
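The `_replaced` guard makes nesting an error, which is easy to trip by accident:

    from colossalai.lazy import LazyInitContext

    with LazyInitContext():
        try:
            with LazyInitContext():  # second __enter__ sees _replaced == True
                pass
        except RuntimeError as e:
            print(e)  # LazyInitContext is not reentrant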
@@ -485,17 +490,17 @@ class LazyInitContext:
                 return args[0]
             elif len(args) == 1:
                 # (object data, *, torch.device device)
-                kwargs = {**kwargs, 'dtype': dtype}
-                replaced, orig = self.overrides['tensor']
+                kwargs = {**kwargs, "dtype": dtype}
+                replaced, orig = self.overrides["tensor"]
                 return replaced(*args, **kwargs)
             elif _is_int_tuple(args):
                 # (tuple of ints size, *, torch.device device)
-                kwargs = {**kwargs, 'dtype': dtype}
-                replaced, orig = self.overrides['empty']
+                kwargs = {**kwargs, "dtype": dtype}
+                replaced, orig = self.overrides["empty"]
                 return replaced(*args, **kwargs)
             else:
                 raise TypeError(
-                    f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)'
+                    f"new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)"
                 )
 
         return wrapper, target
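The three branches of the wrapper mirror the argument forms the legacy constructors accept in stock PyTorch:

    import torch

    a = torch.FloatTensor(2, 3)        # (tuple of ints size): uninitialized, like torch.empty
    b = torch.FloatTensor([1.0, 2.0])  # (object data): copies data, like torch.tensor
    c = torch.FloatTensor(b)           # (Tensor other): the wrapper's first branch returns it as-is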
@@ -514,23 +519,29 @@ class LazyInitContext:
             if callable(getattr(torch, target, None))
         }
 
-        self.overrides.update({
-            target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like'))
-            for target in _NORMAL_FACTORY
-            if callable(getattr(torch, target + '_like', None))
-        })
+        self.overrides.update(
+            {
+                target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like"))
+                for target in _NORMAL_FACTORY
+                if callable(getattr(torch, target + "_like", None))
+            }
+        )
 
-        self.overrides.update({
-            target: wrap_legacy_constructor(getattr(torch, target), dtype)
-            for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
-            if callable(getattr(torch, target, None))
-        })
+        self.overrides.update(
+            {
+                target: wrap_legacy_constructor(getattr(torch, target), dtype)
+                for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
+                if callable(getattr(torch, target, None))
+            }
+        )
 
-        self.overrides.update({
-            target: wrap_no_meta_factory(getattr(torch, target))
-            for target in _NO_META_FACTORY
-            if callable(getattr(torch, target, None))
-        })
+        self.overrides.update(
+            {
+                target: wrap_no_meta_factory(getattr(torch, target))
+                for target in _NO_META_FACTORY
+                if callable(getattr(torch, target, None))
+            }
+        )
 
         for name, (wrapper, orig) in self.overrides.items():
             setattr(torch, name, wrapper)
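The `setattr(torch, name, wrapper)` loop is plain monkey-patching of the factory functions; the mechanism in miniature:

    import torch

    orig_empty = torch.empty

    def traced_empty(*args, **kwargs):
        print("intercepted torch.empty", args)  # a LazyTensor would be constructed here instead
        return orig_empty(*args, **kwargs)

    torch.empty = traced_empty  # what __enter__ does for every entry in self.overrides
    x = torch.empty(2, 2)
    torch.empty = orig_empty    # __exit__ must restore the originals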
@@ -556,10 +567,9 @@ class LazyInitContext:
         return _apply_to_lazy_module(module, apply_fn, verbose)
 
     @staticmethod
-    def distribute(module: nn.Module,
-                   device_mesh: DeviceMesh,
-                   sharding_spec_dict: Dict[str, ShardingSpec],
-                   verbose: bool = False) -> nn.Module:
+    def distribute(
+        module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False
+    ) -> nn.Module:
         """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
 
         Args:
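A hypothetical call sketch, restating only the signature above; the `device_mesh` and `sharding_spec_dict` values are assumed to come from an auto-parallel pass elsewhere:

    model = LazyInitContext.distribute(model, device_mesh, sharding_spec_dict, verbose=True)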
@@ -574,9 +584,9 @@ class LazyInitContext:
         return _apply_to_lazy_module(module, apply_fn, verbose)
 
 
-def _apply_to_lazy_module(module: nn.Module,
-                          apply_fn: Callable[[str, torch.Tensor], None],
-                          verbose: bool = False) -> nn.Module:
+def _apply_to_lazy_module(
+    module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False
+) -> nn.Module:
     if verbose:
         # verbose info
         param_cnt = 0
@@ -590,7 +600,7 @@ def _apply_to_lazy_module(module: nn.Module,
         if verbose:
             param_cnt += 1
             total_numel += p.numel()
-            if getattr(p, '_materialized_data', False) is None:
+            if getattr(p, "_materialized_data", False) is None:
                 # if no _materialized_data attr, the tensor is not lazy
                 param_lazy_cnt += 1
             else:
@@ -612,10 +622,11 @@ def _apply_to_lazy_module(module: nn.Module,
 
     if verbose:
         non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
-        _print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
-        _print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
+        _print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}")
+        _print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}")
         _print_rank_0(
-            f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%')
+            f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%"
+        )
 
     return module
 