[Tensor] add Parameter inheritance for ColoParameter (#1041)

* add Parameter inheritance for ColoParameter

* remove tricks

* remove tricks

* polish

* polish
Ziyue Jiang authored on 2022-05-30 17:23:44 +08:00, committed by GitHub
parent 4d8a574cd3, commit 7c530b9de2
5 changed files with 59 additions and 100 deletions

colossalai/tensor/colo_parameter.py

@@ -3,15 +3,15 @@ from .const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
 from typing import Optional

-class ColoParameter(ColoTensor):
+class ColoParameter(ColoTensor, torch.nn.Parameter):
     r"""A kind of ColoTensor to be considered as a module parameter.
     """

     def __new__(cls,
-                data: torch.Tensor,
+                data: Optional[torch.Tensor] = None,
                 requires_grad: bool = True,
                 spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
         if data is None:
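The dual inheritance is the heart of the change: once ColoParameter subclasses torch.nn.Parameter, nn.Module's isinstance-based registration accepts it directly, which is presumably what let the "remove tricks" commits above drop the earlier workarounds. A minimal self-contained sketch of the effect (the stub ColoTensor below is an assumption standing in for the real class, which also carries a TensorSpec):

import torch

class ColoTensor(torch.Tensor):    # hypothetical stub; the real ColoTensor adds spec/dist logic
    pass

class ColoParameter(ColoTensor, torch.nn.Parameter):
    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.empty(0)    # placeholder for the elided None branch in the diff
        return torch.Tensor._make_subclass(cls, data, requires_grad)

p = ColoParameter(torch.randn(4, 4))
assert isinstance(p, torch.nn.Parameter)    # now True, so nn.Module treats it as a parameter

linear = torch.nn.Linear(4, 4)
linear.weight = p                           # registered in linear._parameters, no tricks needed
assert any(q is p for q in linear.parameters())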
@@ -19,7 +19,7 @@ class ColoParameter(ColoTensor):
         return torch.Tensor._make_subclass(cls, data, requires_grad)

     def __init__(self,
-                 data: torch.Tensor,
+                 data: Optional[torch.Tensor] = None,
                  requires_grad: bool = True,
                  spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
         self._spec = copy(spec)
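Making data optional in __init__ as well keeps the two signatures in step and matches torch.nn.Parameter's own convention, where a missing data argument becomes an empty placeholder tensor (presumably what the elided body of the `if data is None:` branch in the first hunk does). The stock Parameter shows the convention:

import torch

p = torch.nn.Parameter()            # equivalent to torch.nn.Parameter(torch.empty(0))
print(p.shape, p.requires_grad)     # torch.Size([0]) True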
@@ -43,4 +43,30 @@ class ColoParameter(ColoTensor):

     def __repr__(self):
         return f'ColoParameter: {torch.Tensor.__repr__(self)}'
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        else:
+            with torch._C.DisableTorchFunction():
+                data = self.data.clone()
+            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
+            memo[id(self)] = tensor
+            return tensor
+
+    def __reduce_ex__(self, proto):
+        # Adapted from torch._utils._rebuild_parameter
+        # def _rebuild_colo_parameter(data, requires_grad, backward_hooks):
+        #     colo_param = ColoParameter(data, requires_grad)
+        #     colo_param._backward_hooks = backward_hooks
+        #     return colo_param
+        # return (
+        #     _rebuild_colo_parameter,
+        #     (self.data, self.requires_grad, OrderedDict())
+        # )
+
+        # TODO(jzy) we don't support object reflection now.
+        # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
+        raise NotImplementedError
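Two behavioral consequences of the new methods are worth noting. __deepcopy__ honors the memo dict, so a parameter reachable more than once in an object graph (tied weights, for example) is cloned exactly once and stays tied; __reduce_ex__ makes pickling fail fast with NotImplementedError instead of silently dropping the spec, whose process_group is a runtime handle. The memo behavior can be seen with a stock nn.Parameter, which follows the same protocol as the __deepcopy__ added here:

import copy
import torch

p = torch.nn.Parameter(torch.randn(2, 2))
state = {'encoder_weight': p, 'decoder_weight': p}    # tied: both keys reference one object
copied = copy.deepcopy(state)
assert copied['encoder_weight'] is copied['decoder_weight']    # still tied after the copy
assert copied['encoder_weight'] is not p                       # but a genuinely new object

For a ColoParameter itself, pickle.dumps(p) raises NotImplementedError by design, so checkpointing presumably has to serialize the underlying tensor data rather than the parameter object.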