diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 8d99a6a02..948356914 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -3,15 +3,15 @@ from .const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
+from typing import Optional
 
-
-class ColoParameter(ColoTensor):
+class ColoParameter(ColoTensor, torch.nn.Parameter):
     r"""A kind of ColoTensor to be considered as a module parameter.
 
     """
 
     def __new__(cls,
-                data: torch.Tensor,
+                data: Optional[torch.Tensor] = None,
                 requires_grad: bool = True,
                 spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
         if data is None:
@@ -19,7 +19,7 @@ class ColoParameter(ColoTensor):
         return torch.Tensor._make_subclass(cls, data, requires_grad)
 
     def __init__(self,
-                 data: torch.Tensor,
+                 data: Optional[torch.Tensor] = None,
                  requires_grad: bool = True,
                  spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
         self._spec = copy(spec)
@@ -43,4 +43,30 @@ class ColoParameter(ColoTensor):
 
     def __repr__(self):
         return f'ColoParameter: {torch.Tensor.__repr__(self)}'
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        else:
+            with torch._C.DisableTorchFunction():
+                data = self.data.clone()
+            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
+            memo[id(self)] = tensor
+            return tensor
+
+    def __reduce_ex__(self, proto):
+        # Adapted from torch._utils._rebuild_parameter
+        # def _rebuild_colo_parameter(data, requires_grad, backward_hooks):
+        #     colo_param = ColoParameter(data, requires_grad)
+        #     colo_param._backward_hooks = backward_hooks
+        #     return colo_param
+
+        # return (
+        #     _rebuild_colo_parameter,
+        #     (self.data, self.requires_grad, OrderedDict())
+        # )
+
+        # TODO(jzy) we don't support object reflection now.
+        # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
+        raise NotImplementedError
 
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 0f87cc289..058a5fb65 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -24,96 +24,6 @@ def _named_params_with_replica(
             name = mod_prefix + ('.' if mod_prefix else '') + name
             yield name, val
 
-
-# Adapted from torch.nn.module.Module.register_param
-
-
-def _register_parameter_with_colotensor(self, name: str, param):
-    if '_parameters' not in self.__dict__:
-        raise AttributeError("cannot assign parameter before Module.__init__() call")
-
-    if not isinstance(name, torch._six.string_classes):
-        raise TypeError("parameter name should be a string. "
-                        "Got {}".format(torch.typename(name)))
-    if '.' in name:
-        raise KeyError("parameter name can't contain \".\"")
-    if name == '':
-        raise KeyError("parameter name can't be empty string \"\"")
-    if hasattr(self, name) and name not in self._parameters:
-        raise KeyError("attribute '{}' already exists".format(name))
-
-    if param is None:
-        self._parameters[name] = None
-    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
-        raise TypeError("cannot assign '{}' object to parameter '{}' "
-                        "(torch.nn.Parameter or ColoParameter or None required)".format(torch.typename(param), name))
-    elif param.grad_fn:
-        raise ValueError("Cannot assign non-leaf Tensor to parameter '{0}'. Model "
-                         "parameters must be created explicitly. To express '{0}' "
-                         "as a function of another Tensor, compute the value in "
-                         "the forward() method.".format(name))
-    else:
-        self._parameters[name] = param
-
-
-# Adapted from torch.nn.module.Module.__setattr__
-
-
-def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):
-
-    def remove_from(*dicts_or_sets):
-        for d in dicts_or_sets:
-            if name in d:
-                if isinstance(d, dict):
-                    del d[name]
-                else:
-                    d.discard(name)
-
-    params = self.__dict__.get('_parameters')
-    if isinstance(value, (ColoParameter, torch.nn.Parameter)):
-        if params is None:
-            raise AttributeError("cannot assign parameters before Module.__init__() call")
-        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
-        self.register_parameter(name, value)
-    elif params is not None and name in params:
-        if value is not None:
-            raise TypeError("cannot assign '{}' as parameter '{}' "
-                            "(torch.nn.Parameter or None expected)".format(torch.typename(value), name))
-        self.register_parameter(name, value)
-    else:
-        modules = self.__dict__.get('_modules')
-        if isinstance(value, torch.nn.Module):
-            if modules is None:
-                raise AttributeError("cannot assign module before Module.__init__() call")
-            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
-            modules[name] = value
-        elif modules is not None and name in modules:
-            if value is not None:
-                raise TypeError("cannot assign '{}' as child module '{}' "
-                                "(torch.nn.Module or None expected)".format(torch.typename(value), name))
-            modules[name] = value
-        else:
-            buffers = self.__dict__.get('_buffers')
-            if buffers is not None and name in buffers:
-                if value is not None and not isinstance(value, torch.Tensor):
-                    raise TypeError("cannot assign '{}' as buffer '{}' "
-                                    "(torch.Tensor or None expected)".format(torch.typename(value), name))
-                buffers[name] = value
-            else:
-                object.__setattr__(self, name, value)
-
-def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]:
-    module_path, _, param_name = target.rpartition(".")
-
-    mod: torch.nn.Module = self.get_submodule(module_path)
-
-    if not hasattr(mod, param_name):
-        raise AttributeError(mod._get_name() + " has no attribute `"
-                             + param_name + "`")
-
-    param = getattr(mod, param_name)
-    return param
-
 def ColoModulize(module):
     """
     Replacing the parameters() and named_parameters() with our customized ones
@@ -134,10 +44,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device
 
-        torch.nn.Module.__setattr__ = _setattr_with_colotensor
-        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
-        torch.nn.Module.get_parameter = _get_parameter_with_colotensor
-
         self._register_colo_modules()
 
     def _register_colo_modules(self):
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index c9e3da884..682146e1a 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -353,5 +353,5 @@ def _test_pretrain_load(world_size):
 if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
-    test_model(4)
-    # _test_pretrain_load(4)
+    # test_model(4)
+    _test_pretrain_load(4)
diff --git a/tests/test_tensor/test_parameter.py b/tests/test_tensor/test_parameter.py
new file mode 100644
index 000000000..a5c0b15a3
--- /dev/null
+++ b/tests/test_tensor/test_parameter.py
@@ -0,0 +1,26 @@
+from colossalai.tensor import ColoParameter, ColoTensor
+import torch
+from numpy import allclose
+from _utils import tensor_equal
+
+def test_multiinheritance():
+    colo_param = ColoParameter()
+    assert isinstance(colo_param, ColoTensor)
+    assert isinstance(colo_param, torch.nn.Parameter)
+
+    # __deepcopy__ overload
+    import copy
+    colo_param2 = copy.deepcopy(colo_param)
+    assert isinstance(colo_param2, ColoParameter)
+    assert tensor_equal(colo_param.data, colo_param2.data)
+    assert colo_param.requires_grad == colo_param2.requires_grad
+
+    # __repr__ overload
+    assert 'ColoParameter' in str(colo_param)
+
+    # __torch_function__
+    clone_param = torch.clone(colo_param)
+    assert isinstance(clone_param, ColoTensor)
+
+if __name__ == '__main__':
+    test_multiinheritance()
\ No newline at end of file
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index 07df7cdde..96ea93487 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -46,3 +46,4 @@ def test_operand():
     t_ref_res = t_ref + t_ref
     t_res = t + t
     assert torch.allclose(t_ref_res, t_res)
+
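
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the pattern the new ColoParameter relies on, using only stock PyTorch. MyParameter is a hypothetical stand-in for ColoParameter (no ColoTensor, TensorSpec, or distspec involved); it shows how torch.Tensor._make_subclass keeps the subclass type and how a __deepcopy__ override keeps deep copies inside the subclass, which is what tests/test_tensor/test_parameter.py exercises against the real class.

import copy

import torch


class MyParameter(torch.nn.Parameter):
    """Hypothetical stand-in for ColoParameter: a plain torch.nn.Parameter subclass."""

    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            # Mirror torch.nn.Parameter's default: wrap an empty tensor.
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        # Return the cached copy if deepcopy has already visited this object.
        if id(self) in memo:
            return memo[id(self)]
        # Mirrors the patch: clone the raw data with __torch_function__ dispatch disabled.
        with torch._C.DisableTorchFunction():
            data = self.data.clone()
        new_param = MyParameter(data, self.requires_grad)
        memo[id(self)] = new_param
        return new_param


if __name__ == '__main__':
    param = MyParameter(torch.ones(2, 2))
    param_copy = copy.deepcopy(param)
    assert isinstance(param_copy, MyParameter)              # copy stays in the subclass
    assert torch.equal(param.data, param_copy.data)         # values are cloned
    assert param.requires_grad == param_copy.requires_grad

Pickling is a separate concern: as the comment in the patched __reduce_ex__ notes, distspec holds the runtime `process_group` attribute and cannot be pickled or rebuilt, so ColoParameter raises NotImplementedError there instead of following torch._utils._rebuild_parameter.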