[tensor] refactor colo-tensor (#992)

* refactor colo-tensor and update linear op

* polish code

* polish code

* update ops and unit tests

* update unit tests

* polish code

* rename dist_spec module

* polish code

* polish code

* remove unneeded import

* fix pipelinable
ver217 authored on 2022-05-19 12:44:59 +08:00 (committed by GitHub)
parent 1467d83edf
commit ad536e308e
27 changed files with 657 additions and 616 deletions

colossalai/tensor/colo_parameter.py

@@ -1,6 +1,8 @@
 from .colo_tensor import ColoTensor
 from .const import TensorType
 import torch
+from colossalai.tensor import TensorSpec, distspec
+from copy import copy
 
 
 class ColoParameter(ColoTensor):
@@ -8,21 +10,26 @@ class ColoParameter(ColoTensor):
 
     """
 
-    def __init__(self, *args, **kargs):
-        super().__init__(*args, **kargs)
-        self._type = TensorType.MODEL
+    def __new__(cls,
+                data: torch.Tensor,
+                requires_grad: bool = True,
+                spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
+        if data is None:
+            data = torch.empty(0)
+        return torch.Tensor._make_subclass(cls, data, requires_grad)
 
-    def __new__(cls, *args, **kwargs):
-        t = super(ColoParameter, cls).__new__(cls)
-        t._type = TensorType.MODEL
-        return t
+    def __init__(self,
+                 data: torch.Tensor,
+                 requires_grad: bool = True,
+                 spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
+        self._spec = copy(spec)
+        self._type = TensorType.MODEL
+        self._graph_node = None
 
     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter':
-        colo_p = ColoParameter(*tensor.size(),
-                               dtype=tensor.dtype,
-                               requires_grad=tensor.requires_grad,
-                               pin_memory=tensor.is_pinned(),
-                               device=tensor.device,
-                               torch_tensor=tensor if save_payload else torch.empty(0))
-        return colo_p
+    def from_torch_tensor(tensor: torch.Tensor,
+                          requires_grad: bool = True,
+                          spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
+        tensor = tensor.as_subclass(ColoParameter)
+        tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
+        return tensor
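
For context, a minimal usage sketch of the construction path this refactor introduces, assuming ColoParameter, TensorSpec, and distspec are all re-exported from colossalai.tensor as the new imports suggest: the removed init_from_torch_tensor(..., save_payload=...) helper is replaced by from_torch_tensor, which re-tags an existing torch.Tensor as a ColoParameter via as_subclass instead of allocating a separate payload.

import torch
from colossalai.tensor import ColoParameter, TensorSpec, distspec

# A plain torch tensor that will serve as the parameter payload.
data = torch.randn(4, 8)

# New-style construction: re-tag the existing tensor as a ColoParameter
# with a replicated distribution spec (the default in the new signature).
param = ColoParameter.from_torch_tensor(data,
                                        requires_grad=True,
                                        spec=TensorSpec(distspec.replicate()))

assert isinstance(param, ColoParameter)      # subclass view, not a copy
assert param.data_ptr() == data.data_ptr()   # storage is shared with `data`

Note the semantic shift visible in the diff: the old path copied metadata (dtype, device, pin_memory) into a freshly built ColoParameter and could drop the payload when save_payload was False, while the new path always shares storage with the source tensor and carries the distribution spec in self._spec.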