Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00.
[tensor] ColoTensor supports ZeRO (#1015)

* impl chunk manager
* impl param op hook
* add reduce_chunk
* add zero hook v2
* add zero dp
* fix TensorInfo
* impl load balancing when using zero without chunk
* fix zero hook
* polish chunk
* fix bugs
* ddp ok
* zero ok
* polish code
* fix bugs about load balancing
* polish code
* polish code
* add end-to-end test
* polish code
* polish code
* polish code
* fix typo
* add test_chunk
* fix bugs
* fix bugs
* polish code
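For context on the "chunk manager" and "reduce_chunk" items: chunk-based ZeRO packs many parameters into fixed-size flat buffers ("chunks") so that sharding and gradient reduction can operate on a whole chunk at once instead of one tensor at a time. The sketch below illustrates only that packing idea; the `Chunk` class and its `can_fit`/`append` methods are hypothetical names, not the ColossalAI ChunkManager API.

```python
import torch

# Hypothetical sketch of chunk packing: parameters share one flat buffer so
# a single collective on `data` covers every parameter in the chunk.
class Chunk:
    def __init__(self, capacity: int, dtype: torch.dtype = torch.float32):
        self.data = torch.zeros(capacity, dtype=dtype)  # flat backing storage
        self.offset = 0                                 # next free element

    def can_fit(self, t: torch.Tensor) -> bool:
        return self.offset + t.numel() <= self.data.numel()

    def append(self, t: torch.Tensor) -> None:
        # Copy the tensor into the buffer, then re-point its storage at the
        # buffer slice, so the parameter now lives inside the chunk.
        n = t.numel()
        self.data[self.offset:self.offset + n].copy_(t.detach().reshape(-1))
        t.data = self.data[self.offset:self.offset + n].view(t.shape)
        self.offset += n
```

The "load balancing" items in the message would then amount to a policy for deciding which chunk (and which rank's shard) each new parameter lands in.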
@@ -3,8 +3,10 @@ from .const import TensorType

import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy
from .param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
from typing import Optional


class ColoParameter(ColoTensor, torch.nn.Parameter):
    r"""A kind of ColoTensor to be considered as a module parameter.
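`PreFwdPostBwd` and `PostFwdPreBwd`, imported here, are identity autograd functions: their forward passes inputs through unchanged and exists only to fire hooks at the right points of forward and backward. The sketch below shows one way such a pair can be written; it is a simplified illustration, with a hypothetical `Hooks.registry` standing in for `_ParamOpHookWrapper.hooks` and assumed callback names (`pre_forward`, `post_forward`, `pre_backward`, `post_backward`), not the actual `param_op_hook` module.

```python
import torch

class Hooks:
    registry = []  # stand-in for _ParamOpHookWrapper.hooks

class PreFwdPostBwd(torch.autograd.Function):
    """Fires pre-forward hooks now; post-backward hooks when grads flow back."""

    @staticmethod
    def forward(ctx, params, *args):
        ctx.params = params
        for hook in Hooks.registry:
            hook.pre_forward(params)        # e.g. gather ZeRO shards before the op
        return args                         # identity on the op's inputs

    @staticmethod
    def backward(ctx, *grads):
        for hook in Hooks.registry:
            hook.post_backward(ctx.params)  # e.g. release shards, reduce grads
        return (None,) + grads              # no grad for the params list itself

class PostFwdPreBwd(torch.autograd.Function):
    """Fires post-forward hooks now; pre-backward hooks when grads flow back."""

    @staticmethod
    def forward(ctx, params, ret):
        ctx.params = params
        for hook in Hooks.registry:
            hook.post_forward(params)
        return ret                          # identity on the op's output

    @staticmethod
    def backward(ctx, *grads):
        for hook in Hooks.registry:
            hook.pre_backward(ctx.params)
        return (None,) + grads
```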
@@ -44,6 +46,22 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):

    def __repr__(self):
        return f'ColoParameter: {torch.Tensor.__repr__(self)}'

    @classmethod
    def __torch_function__(cls, func, types, args=..., kwargs=None):
        if len(_ParamOpHookWrapper.hooks) > 0:
            if not func.__name__.startswith('__'):
                params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
                if kwargs is not None:
                    params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
                if len(params) > 0:
                    with torch._C.DisableTorchFunction():
                        args = PreFwdPostBwd.apply(params, *args)
                    ret = super().__torch_function__(func, types, args, kwargs)
                    with torch._C.DisableTorchFunction():
                        ret = PostFwdPreBwd.apply(params, ret)
                    return ret
        return super().__torch_function__(func, types, args, kwargs)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
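Continuing the sketch above, the bracketing that `__torch_function__` performs around each op can be exercised by hand. With a hook registered, a single linear op fires the four callbacks in the order a ZeRO hook relies on: gather before the op, release after it, and the mirror image during backward. `TraceHook` is hypothetical, and a plain tensor stands in for a `ColoParameter`.

```python
import torch

class TraceHook:
    # Minimal hook that just records the callback order.
    def pre_forward(self, params):   print('pre_forward')
    def post_forward(self, params):  print('post_forward')
    def pre_backward(self, params):  print('pre_backward')
    def post_backward(self, params): print('post_backward')

Hooks.registry.append(TraceHook())

x = torch.randn(2, 4)
w = torch.randn(3, 4, requires_grad=True)   # stand-in for a ColoParameter

# What __torch_function__ does around one op:
x, w2 = PreFwdPostBwd.apply([w], x, w)      # prints 'pre_forward'
y = torch.nn.functional.linear(x, w2)
y = PostFwdPreBwd.apply([w], y)             # prints 'post_forward'
y.sum().backward()                          # prints 'pre_backward', then 'post_backward'
```

Note that `__torch_function__` skips dunder ops (`func.__name__.startswith('__')`) so that bookkeeping calls such as `__repr__` do not trigger the hooks.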
@@ -69,4 +87,3 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):

        # TODO(jzy) we don't support object reflection now.
        # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
        raise NotImplementedError
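A short hypothetical illustration of this limitation, assuming `w` is a `ColoParameter` and the `raise` above sits on the serialization path (for tensor subclasses, `pickle.dumps` goes through `__reduce_ex__`):

```python
import pickle

try:
    pickle.dumps(w)  # w: a ColoParameter (assumed, per the diff above)
except NotImplementedError:
    # distspec references the live `process_group`, so it cannot be
    # pickled or rebuilt from serialized state yet.
    print('ColoParameter does not support pickling')
```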