Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 20:10:17 +00:00
[Tensor] add Parameter inheritance for ColoParameter (#1041)
* add Parameter inheritance for ColoParameter
* remove tricks
* remove tricks
* polish
* polish
@@ -3,15 +3,15 @@ from .const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
 from typing import Optional
 
 
-class ColoParameter(ColoTensor):
+class ColoParameter(ColoTensor, torch.nn.Parameter):
     r"""A kind of ColoTensor to be considered as a module parameter.
 
     """
 
     def __new__(cls,
-                data: torch.Tensor,
+                data: Optional[torch.Tensor] = None,
                 requires_grad: bool = True,
                 spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
         if data is None:
@@ -19,7 +19,7 @@ class ColoParameter(ColoTensor):
         return torch.Tensor._make_subclass(cls, data, requires_grad)
 
     def __init__(self,
-                 data: torch.Tensor,
+                 data: Optional[torch.Tensor] = None,
                  requires_grad: bool = True,
                  spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
         self._spec = copy(spec)
@@ -43,4 +43,30 @@ class ColoParameter(ColoTensor):
 
     def __repr__(self):
         return f'ColoParameter: {torch.Tensor.__repr__(self)}'
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        else:
+            with torch._C.DisableTorchFunction():
+                data = self.data.clone()
+            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
+            memo[id(self)] = tensor
+            return tensor
+
+    def __reduce_ex__(self, proto):
+        # Adapted from torch._utils._rebuild_parameter
+        # def _rebuild_colo_parameter(data, requires_grad, backward_hooks):
+        #     colo_param = ColoParameter(data, requires_grad)
+        #     colo_param._backward_hooks = backward_hooks
+        #     return colo_param
+
+        # return (
+        #     _rebuild_colo_parameter,
+        #     (self.data, self.requires_grad, OrderedDict())
+        # )
+
+        # TODO(jzy) we don't support object reflection now.
+        # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
+        raise NotImplementedError
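Why the extra base class matters: `torch.nn.Module` decides whether to register an attribute as a parameter with an `isinstance(value, torch.nn.Parameter)` check, so once `ColoParameter` also inherits from `torch.nn.Parameter`, it shows up in `named_parameters()` and optimizer collection without the removed tricks. A minimal sketch in plain PyTorch (no ColossalAI; `MyTensor` and `MyParameter` are hypothetical stand-ins for `ColoTensor` and `ColoParameter`):

import torch
import torch.nn as nn


class MyTensor(torch.Tensor):
    """Stand-in for ColoTensor: an ordinary torch.Tensor subclass."""


class MyParameter(MyTensor, nn.Parameter):
    """Stand-in for ColoParameter: a tensor subclass that is also a Parameter."""

    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.empty(0)
        # Same construction as in the diff: wrap `data` in this subclass
        # without copying its storage.
        return torch.Tensor._make_subclass(cls, data, requires_grad)


module = nn.Linear(4, 4)
module.weight = MyParameter(torch.randn(4, 4))
# Because isinstance(module.weight, nn.Parameter) is now True, nn.Module
# registers the assignment like any other parameter.
assert 'weight' in dict(module.named_parameters())
assert type(module.weight) is MyParameter

Note the base-class order mirrors the diff: `(MyTensor, nn.Parameter)` gives a consistent MRO because `MyTensor` is not itself a `Parameter` subclass.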
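The `__deepcopy__` override follows the standard recipe from `torch.nn.Parameter.__deepcopy__`: return the cached copy from `memo` if this object was already copied (so a parameter shared between modules stays shared), clone the payload with torch-function dispatch disabled, and rebuild the subclass around the clone. A hedged usage sketch on a standalone variant of the hypothetical `MyParameter` above, without the ColossalAI-specific `spec` argument:

import copy

import torch
import torch.nn as nn


class MyParameter(nn.Parameter):

    def __deepcopy__(self, memo):
        # Same structure as the diff, minus the ColossalAI spec.
        if id(self) in memo:
            return memo[id(self)]
        # DisableTorchFunction keeps .clone() from re-entering subclass hooks.
        with torch._C.DisableTorchFunction():
            data = self.data.clone()
        tensor = MyParameter(data, self.requires_grad)
        memo[id(self)] = tensor
        return tensor


p = MyParameter(torch.randn(2, 2))
state = {'a': p, 'b': p}             # two references to one parameter
copied = copy.deepcopy(state)
assert copied['a'] is copied['b']    # the memo keeps the aliasing intact
assert type(copied['a']) is MyParameter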
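`__reduce_ex__` is the hook that pickle (and therefore `torch.save`) goes through, so raising `NotImplementedError` makes serialization fail loudly: as the TODO notes, a distspec holds the runtime attribute `process_group`, which cannot be pickled or rebuilt. The commented-out code shows the intended recipe, adapted from `torch._utils._rebuild_parameter`: return a rebuild callable plus the state it needs. A hedged sketch of that protocol, assuming no distributed spec is involved (`_rebuild_my_parameter` and `MyParameter` are hypothetical):

import pickle
from collections import OrderedDict

import torch
import torch.nn as nn


def _rebuild_my_parameter(data, requires_grad, backward_hooks):
    # Mirrors torch._utils._rebuild_parameter for the subclass.
    param = MyParameter(data, requires_grad)
    param._backward_hooks = backward_hooks
    return param


class MyParameter(nn.Parameter):

    def __reduce_ex__(self, proto):
        # The recipe the commented-out code sketches: a module-level callable
        # plus the arguments pickle should feed it at load time.
        return (_rebuild_my_parameter, (self.data, self.requires_grad, OrderedDict()))


p = MyParameter(torch.randn(3))
q = pickle.loads(pickle.dumps(p))
assert type(q) is MyParameter
assert torch.equal(p.data, q.data)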