[tensor] ColoTensor supports ZeRO (#1015)

* impl chunk manager

* impl param op hook

* add reduce_chunk

* add zero hook v2

* add zero dp

* fix TensorInfo

* impl load balancing when using zero without chunk

* fix zero hook

* polish chunk

* fix bugs

* ddp ok

* zero ok

* polish code

* fix bugs about load balancing

* polish code

* polish code

* add end-to-end test

* polish code

* polish code

* polish code

* fix typo

* add test_chunk

* fix bugs

* fix bugs

* polish code
This commit is contained in:
ver217
2022-05-31 12:00:12 +08:00
committed by GitHub
parent cfa6c1b46b
commit 9492a561c3
8 changed files with 618 additions and 4 deletions

View File

@@ -3,8 +3,10 @@ from .const import TensorType
import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy
from .param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
from typing import Optional
class ColoParameter(ColoTensor, torch.nn.Parameter):
r"""A kind of ColoTensor to be considered as a module parameter.
@@ -44,6 +46,22 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
def __repr__(self):
    """Return the plain ``torch.Tensor`` repr prefixed with ``ColoParameter: ``."""
    # Call torch.Tensor.__repr__ explicitly rather than going through the
    # instance, so the text comes from the base tensor formatter.
    tensor_repr = torch.Tensor.__repr__(self)
    return f'ColoParameter: {tensor_repr}'
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """Dispatch a torch function, wrapping it with the parameter-op hooks.

    When at least one hook is registered on ``_ParamOpHookWrapper`` and
    ``func`` is not a dunder function, every ``ColoParameter`` found among
    the positional arguments and the keyword-argument values is collected.
    If any are found, the call is bracketed by ``PreFwdPostBwd.apply`` and
    ``PostFwdPreBwd.apply``; otherwise the call is forwarded unchanged to
    the parent ``__torch_function__``.

    Args:
        func: the torch function being overridden.
        types: the types forwarded by the torch-function protocol.
        args: positional arguments of ``func``. Defaults to an empty tuple
            (the previous ``...`` default was not iterable and would fail
            as soon as it was scanned or unpacked).
        kwargs: keyword arguments of ``func``, or ``None``.

    Returns:
        The result of the parent ``__torch_function__``, passed through
        ``PostFwdPreBwd.apply`` when the hooks were engaged.
    """
    if len(_ParamOpHookWrapper.hooks) > 0 and not func.__name__.startswith('__'):
        # Only top-level ColoParameter arguments are considered; containers
        # holding parameters are not inspected.
        params = [arg for arg in args if isinstance(arg, ColoParameter)]
        if kwargs is not None:
            params.extend(value for value in kwargs.values() if isinstance(value, ColoParameter))
        if params:
            # Torch-function dispatch is disabled around the hook functions
            # so that applying them does not re-enter this method.
            with torch._C.DisableTorchFunction():
                # NOTE(review): presumably returns the (possibly rewrapped)
                # positional arguments -- confirm against param_op_hook.
                args = PreFwdPostBwd.apply(params, *args)
            ret = super().__torch_function__(func, types, args, kwargs)
            with torch._C.DisableTorchFunction():
                ret = PostFwdPreBwd.apply(params, ret)
            return ret
    return super().__torch_function__(func, types, args, kwargs)
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
@@ -69,4 +87,3 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
# TODO(jzy) we don't support object reflection now.
# distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
raise NotImplementedError