[ColoTensor] improves init functions. (#1150)

Author: Jiarui Fang
Date: 2022-06-21 18:28:38 +08:00 (committed by GitHub)
Commit: 8cdce0399c
Parent: 8106d7b8c7
5 changed files with 103 additions and 40 deletions


@@ -1,12 +1,13 @@
from .op_wrapper import _COLOSSAL_OPS
+from .const import TensorType
from copy import copy
import torch
+from torch.overrides import get_default_nowrap_functions
from colossalai.tensor import TensorSpec
-from .const import TensorType
from colossalai.tensor import distspec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec
-from torch.overrides import get_default_nowrap_functions
def _convert_output(output):
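Only the signature of `_convert_output` survives in this hunk. In `__torch_function__`-based subclasses like `ColoTensor`, such a helper typically rewraps plain `torch.Tensor` results in the subclass so the spec metadata survives arbitrary torch ops. The sketch below is a generic illustration of that pattern, not the file's verbatim body; `MyTensor` and the container recursion are assumptions.

```python
import torch
from torch.overrides import get_default_nowrap_functions

class MyTensor(torch.Tensor):  # hypothetical stand-in for ColoTensor
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        ret = super().__torch_function__(func, types, args, kwargs)
        if func in get_default_nowrap_functions():
            return ret  # these results are intentionally left unwrapped
        return _convert_output(ret)

def _convert_output(output):
    # Rewrap bare torch.Tensor results in the subclass so wrapper types
    # survive through torch ops; recurse into list/tuple containers.
    if type(output) is torch.Tensor:
        return output.as_subclass(MyTensor)
    if isinstance(output, (list, tuple)):
        return type(output)(_convert_output(o) for o in output)
    return output

t = torch.randn(2, 3).as_subclass(MyTensor)
print(type(t + 1).__name__)  # MyTensor, not torch.Tensor
```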
@@ -18,34 +19,54 @@ def _convert_output(output):
class ColoTensor(torch.Tensor):
-    """ Data Structure for Tensor in Colossal-AI
-    1. It contains a torch.Tensor as an attribute.
-    2. It supports lazy init the tensor's payload.
-    3. It can hijack the torch functions which using ColoTensors as args to our customized functions.
-    4. It supports distributing the tensor's payload to the shards among processes. (TODO)
+    """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
+
+    Args:
+        data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
+        spec (TensorSpec, optional): the tensor spec used at initialization. Defaults to TensorSpec(distspec.replicate()).
+
+    The signature of __init__ has to be consistent with that of __new__, except for the first argument.
+    The class can be initialized from a torch tensor in the following two ways.
+    1. Direct init.
+    >>> colo_t1 = ColoTensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))
+    >>> # If initialized in a sharded model, the tensor passed in is one shard of the global tensor.
+    >>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
+    >>>                             dims=[0],
+    >>>                             num_partitions=[world_size])
+    >>> tensor_spec = TensorSpec(shard_spec)
+    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+    2. Use the static method from_torch_tensor.
+    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))
    """
    def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
        """__new__
        The signature of __new__ has to be consistent with that of torch.Tensor.
        Args:
            data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
            spec (TensorSpec, optional): the tensor spec used at initialization. Defaults to TensorSpec(distspec.replicate()).
        Returns:
            ColoTensor: a ColoTensor that wraps the data.
        """
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)
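A note on the `__new__` body above: `torch.Tensor._make_subclass` re-types the payload as the subclass without copying its storage, so per-instance metadata has to be attached in `__init__`, which Python calls with the same arguments as `__new__`. A minimal standalone sketch of that split; the `WrappedTensor` name and `tag` field are illustrative, not part of the commit:

```python
import torch

class WrappedTensor(torch.Tensor):
    def __new__(cls, data: torch.Tensor, tag: str = "default"):
        if data is None:
            data = torch.empty(0)
        # Shares data's storage and preserves requires_grad; no copy is made.
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    def __init__(self, data: torch.Tensor, tag: str = "default"):
        # Called with the same arguments as __new__, which is why the two
        # signatures must stay consistent.
        self.tag = tag

w = WrappedTensor(torch.randn(2, 3), tag="shard-0")
assert w.tag == "shard-0" and isinstance(w, torch.Tensor)
```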
    def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
-        self._spec = copy(spec)
+        self._tensor_spec = copy(spec)
        self._type = TensorType.NONMODEL
        self._graph_node = None
    @property
    def spec(self) -> TensorSpec:
-        return self._spec
+        return self._tensor_spec
    def set_spec(self, spec: TensorSpec) -> None:
        spec = copy(spec)
-        self.convert_to_dist_spec_(spec.dist_spec)
-        self._spec = spec
+        self._convert_to_dist_spec(spec.dist_spec)
+        self._tensor_spec = spec
    def has_spec(self) -> bool:
-        return self._spec.parallel_action is not None
+        return self._tensor_spec.parallel_action is not None

    def is_model_data(self) -> bool:
        return self._type == TensorType.MODEL
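The `set_spec` path above follows a strict copy-then-convert-then-record order: the transfer in `_convert_to_dist_spec` is computed from the spec that is *currently* stored, so the new spec may only be written afterwards. A plain-Python sketch of that contract, with `Spec` and `Holder` as hypothetical stand-ins (no ColossalAI dependency), to make the ordering explicit:

```python
from copy import copy

class Spec:  # stand-in for TensorSpec
    def __init__(self, dist_spec, parallel_action=None):
        self.dist_spec = dist_spec
        self.parallel_action = parallel_action

class Holder:  # stand-in for ColoTensor's spec bookkeeping
    def __init__(self, spec):
        self._tensor_spec = copy(spec)

    def _convert_to_dist_spec(self, dist_spec):
        # Reads the currently stored spec as the source layout, which is
        # why set_spec may only overwrite it after this call returns.
        print(f"redistribute: {self._tensor_spec.dist_spec} -> {dist_spec}")

    def set_spec(self, spec):
        spec = copy(spec)                            # don't alias caller state
        self._convert_to_dist_spec(spec.dist_spec)   # move payload first
        self._tensor_spec = spec                     # then record metadata

h = Holder(Spec("replicate"))
h.set_spec(Spec("shard(dim=0)"))
print(h._tensor_spec.dist_spec)  # shard(dim=0)
```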
@@ -74,16 +95,16 @@ class ColoTensor(torch.Tensor):
    def is_model_data(self) -> bool:
        return self._type == TensorType.MODEL

-    def convert_to_dist_spec_(self, dist_spec: _DistSpec) -> None:
+    def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
        with DistSpecManager.no_grad():
            self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
-        self._spec.dist_spec = dist_spec
+        self._tensor_spec.dist_spec = dist_spec

    def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
-        spec = copy(self._spec)
-        spec.dist_spec = dist_spec
+        tensor_spec = copy(self._tensor_spec)
+        tensor_spec.dist_spec = dist_spec
        ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
-        return ColoTensor.from_torch_tensor(ret, spec)
+        return ColoTensor.from_torch_tensor(ret, tensor_spec)
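The two conversion entry points above differ only in where the result lands: `_convert_to_dist_spec` swaps `self.data` in place under `DistSpecManager.no_grad()`, while `convert_to_dist_spec` returns a fresh `ColoTensor` and leaves the original untouched. A hedged usage sketch, assuming a process group already initialized via `colossalai.launch` and a data-parallel world size that divides dim 0:

```python
import torch
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.tensor import ColoTensor, TensorSpec, distspec

# Assumes colossalai.launch has set up the DATA process group.
world_size = gpc.get_world_size(ParallelMode.DATA)
shard = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
                       dims=[0],
                       num_partitions=[world_size])

t = ColoTensor.from_torch_tensor(torch.randn(4, 8), TensorSpec(distspec.replicate()))

sharded = t.convert_to_dist_spec(shard)   # out-of-place: t keeps its layout
t.set_spec(TensorSpec(shard))             # in-place: t.data is re-sharded
```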
    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':