[Tensor] add Parameter inheritance for ColoParameter (#1041)

* add Parameter inheritance for ColoParameter

* remove tricks

* remove tricks

* polish

* polish
Ziyue Jiang 2022-05-30 17:23:44 +08:00 committed by GitHub
parent 4d8a574cd3
commit 7c530b9de2
5 changed files with 59 additions and 100 deletions


@@ -3,15 +3,15 @@ from .const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
+from typing import Optional


-class ColoParameter(ColoTensor):
+class ColoParameter(ColoTensor, torch.nn.Parameter):
     r"""A kind of ColoTensor to be considered as a module parameter.
     """

     def __new__(cls,
-                data: torch.Tensor,
+                data: Optional[torch.Tensor] = None,
                 requires_grad: bool = True,
                 spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
         if data is None:
@@ -19,7 +19,7 @@ class ColoParameter(ColoTensor):
         return torch.Tensor._make_subclass(cls, data, requires_grad)

     def __init__(self,
-                 data: torch.Tensor,
+                 data: Optional[torch.Tensor] = None,
                  requires_grad: bool = True,
                  spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
         self._spec = copy(spec)
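
The inheritance change above is the heart of the commit: once ColoParameter derives from both ColoTensor and torch.nn.Parameter, plain `isinstance(p, torch.nn.Parameter)` checks pass, so PyTorch's stock module machinery accepts it directly. A minimal sketch of the pattern with hypothetical stand-in classes (MyTensor/MyParameter are illustrations, not part of colossalai):

import torch

class MyTensor(torch.Tensor):
    """Stand-in for a ColoTensor-like subclass."""
    pass

class MyParameter(MyTensor, torch.nn.Parameter):
    """A tensor subclass that is also a torch.nn.Parameter."""

    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.empty(0)   # mirrors the Optional[torch.Tensor] default above
        # _make_subclass wraps `data` as a leaf tensor of type `cls`
        return torch.Tensor._make_subclass(cls, data, requires_grad)

p = MyParameter(torch.randn(2, 2))
assert isinstance(p, torch.nn.Parameter)   # True thanks to multiple inheritance
assert isinstance(p, MyTensor)
assert p.requires_grad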
@@ -44,3 +44,29 @@ class ColoParameter(ColoTensor):

     def __repr__(self):
         return f'ColoParameter: {torch.Tensor.__repr__(self)}'
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        else:
+            with torch._C.DisableTorchFunction():
+                data = self.data.clone()
+            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
+            memo[id(self)] = tensor
+            return tensor
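
`__deepcopy__` clones the payload inside `torch._C.DisableTorchFunction()` so the clone bypasses `__torch_function__` dispatch, then re-wraps the plain tensor in a fresh ColoParameter with a copied spec. What disabling dispatch means, illustrated with a hypothetical logging subclass:

import torch

class Logged(torch.Tensor):
    """Hypothetical subclass that logs every dispatched torch function."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        print('dispatched:', func.__name__)
        return super().__torch_function__(func, types, args, kwargs or {})

t = torch.randn(3).as_subclass(Logged)
t.clone()                               # prints 'dispatched: clone'
with torch._C.DisableTorchFunction():
    t.clone()                           # no dispatch, nothing printed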
+
+    def __reduce_ex__(self, proto):
+        # Adapted from torch._utils._rebuild_parameter:
+        # def _rebuild_colo_parameter(data, requires_grad, backward_hooks):
+        #     colo_param = ColoParameter(data, requires_grad)
+        #     colo_param._backward_hooks = backward_hooks
+        #     return colo_param
+        # return (
+        #     _rebuild_colo_parameter,
+        #     (self.data, self.requires_grad, OrderedDict())
+        # )
+        # TODO(jzy) we don't support object reflection now.
+        # distspec cannot be pickled or rebuilt because it is tightly coupled
+        # to the runtime attribute `process_group`.
+        raise NotImplementedError
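
Since `__reduce_ex__` raises, the parameter object itself cannot be pickled, so `torch.save(colo_param)` fails until distspec's `process_group` becomes serializable; the commented-out `_rebuild_parameter`-style path is kept as a reference. A hedged workaround sketch that mirrors the `__deepcopy__` logic above (assumes colossalai with this commit applied):

import torch
from colossalai.tensor import ColoParameter

colo_param = ColoParameter(torch.randn(4))
# torch.save(colo_param, 'p.pt') would call __reduce_ex__ and raise
# NotImplementedError, so persist a plain-tensor copy of the payload
# and rebuild the parameter by hand.
with torch._C.DisableTorchFunction():
    plain = colo_param.data.clone()
torch.save(plain, 'p.pt')
rebuilt = ColoParameter(torch.load('p.pt'), colo_param.requires_grad)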


@@ -24,96 +24,6 @@ def _named_params_with_replica(
         name = mod_prefix + ('.' if mod_prefix else '') + name
         yield name, val

-# Adapted from torch.nn.module.Module.register_param
-def _register_parameter_with_colotensor(self, name: str, param):
-    if '_parameters' not in self.__dict__:
-        raise AttributeError("cannot assign parameter before Module.__init__() call")
-
-    if not isinstance(name, torch._six.string_classes):
-        raise TypeError("parameter name should be a string. "
-                        "Got {}".format(torch.typename(name)))
-    if '.' in name:
-        raise KeyError("parameter name can't contain \".\"")
-    if name == '':
-        raise KeyError("parameter name can't be empty string \"\"")
-    if hasattr(self, name) and name not in self._parameters:
-        raise KeyError("attribute '{}' already exists".format(name))
-
-    if param is None:
-        self._parameters[name] = None
-    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
-        raise TypeError("cannot assign '{}' object to parameter '{}' "
-                        "(torch.nn.Parameter or ColoParameter or None required)".format(torch.typename(param), name))
-    elif param.grad_fn:
-        raise ValueError("Cannot assign non-leaf Tensor to parameter '{0}'. Model "
-                         "parameters must be created explicitly. To express '{0}' "
-                         "as a function of another Tensor, compute the value in "
-                         "the forward() method.".format(name))
-    else:
-        self._parameters[name] = param
-
-
-# Adapted from torch.nn.module.Module.__setattr__
-def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):
-    def remove_from(*dicts_or_sets):
-        for d in dicts_or_sets:
-            if name in d:
-                if isinstance(d, dict):
-                    del d[name]
-                else:
-                    d.discard(name)
-
-    params = self.__dict__.get('_parameters')
-    if isinstance(value, (ColoParameter, torch.nn.Parameter)):
-        if params is None:
-            raise AttributeError("cannot assign parameters before Module.__init__() call")
-        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
-        self.register_parameter(name, value)
-    elif params is not None and name in params:
-        if value is not None:
-            raise TypeError("cannot assign '{}' as parameter '{}' "
-                            "(torch.nn.Parameter or None expected)".format(torch.typename(value), name))
-        self.register_parameter(name, value)
-    else:
-        modules = self.__dict__.get('_modules')
-        if isinstance(value, torch.nn.Module):
-            if modules is None:
-                raise AttributeError("cannot assign module before Module.__init__() call")
-            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
-            modules[name] = value
-        elif modules is not None and name in modules:
-            if value is not None:
-                raise TypeError("cannot assign '{}' as child module '{}' "
-                                "(torch.nn.Module or None expected)".format(torch.typename(value), name))
-            modules[name] = value
-        else:
-            buffers = self.__dict__.get('_buffers')
-            if buffers is not None and name in buffers:
-                if value is not None and not isinstance(value, torch.Tensor):
-                    raise TypeError("cannot assign '{}' as buffer '{}' "
-                                    "(torch.Tensor or None expected)".format(torch.typename(value), name))
-                buffers[name] = value
-            else:
-                object.__setattr__(self, name, value)
-
-
-def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]:
-    module_path, _, param_name = target.rpartition(".")
-    mod: torch.nn.Module = self.get_submodule(module_path)
-
-    if not hasattr(mod, param_name):
-        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
-
-    param = getattr(mod, param_name)
-    return param
-

 def ColoModulize(module):
     """
     Replacing the parameters() and named_parameters() with our customized ones
@@ -134,10 +44,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device

-        torch.nn.Module.__setattr__ = _setattr_with_colotensor
-        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
-        torch.nn.Module.get_parameter = _get_parameter_with_colotensor
-
         self._register_colo_modules()

     def _register_colo_modules(self):
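
With the inheritance in place, all three monkey patches removed above become unnecessary: the stock `Module.__setattr__`, `register_parameter`, and `get_parameter` already handle any value that passes `isinstance(value, torch.nn.Parameter)`. A quick check using the same hypothetical MyParameter stand-in from the first sketch:

import torch

class MyTensor(torch.Tensor):
    pass

class MyParameter(MyTensor, torch.nn.Parameter):
    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

m = torch.nn.Module()
m.weight = MyParameter(torch.randn(4, 4))       # unpatched __setattr__ registers it
assert 'weight' in dict(m.named_parameters())   # visible as an ordinary parameter
assert m.get_parameter('weight') is m.weight    # unpatched get_parameter works too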


@@ -353,5 +353,5 @@ def _test_pretrain_load(world_size):

 if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
-    test_model(4)
-    # _test_pretrain_load(4)
+    # test_model(4)
+    _test_pretrain_load(4)


@@ -0,0 +1,26 @@
+from colossalai.tensor import ColoParameter, ColoTensor
+import torch
+
+from numpy import allclose
+from _utils import tensor_equal
+
+
+def test_multiinheritance():
+    colo_param = ColoParameter()
+    assert isinstance(colo_param, ColoTensor)
+    assert isinstance(colo_param, torch.nn.Parameter)
+
+    # __deepcopy__ overload
+    import copy
+    colo_param2 = copy.deepcopy(colo_param)
+    assert isinstance(colo_param2, ColoParameter)
+    assert tensor_equal(colo_param.data, colo_param2.data)
+    assert colo_param.requires_grad == colo_param2.requires_grad
+
+    # __repr__ overload
+    assert 'ColoParameter' in str(colo_param)
+
+    # __torch_function__
+    clone_param = torch.clone(colo_param)
+    assert isinstance(clone_param, ColoTensor)
+
+if __name__ == '__main__':
+    test_multiinheritance()
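
The last assertion leans on the default `Tensor.__torch_function__`, which re-wraps the results of torch ops as the calling subclass; that is why `torch.clone` on a ColoParameter still yields a ColoTensor. The same behaviour shown with a bare stand-in subclass:

import torch

class MyTensor(torch.Tensor):
    pass

t = torch.randn(2).as_subclass(MyTensor)
# The default __torch_function__ re-wraps op results as the subclass.
assert isinstance(torch.clone(t), MyTensor)
assert isinstance(t + t, MyTensor)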


@@ -46,3 +46,4 @@ def test_operand():
     t_ref_res = t_ref + t_ref
     t_res = t + t
     assert torch.allclose(t_ref_res, t_res)