mirror of https://github.com/hpcaitech/ColossalAI.git
[ColoTensor] improves init functions. (#1150)
@@ -1,12 +1,13 @@
 from .op_wrapper import _COLOSSAL_OPS
-from .const import TensorType
 from copy import copy
 import torch
-from torch.overrides import get_default_nowrap_functions
 
 from colossalai.tensor import TensorSpec
+from .const import TensorType
 from colossalai.tensor import distspec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec
+from torch.overrides import get_default_nowrap_functions
+
 
 def _convert_output(output):
@@ -18,34 +19,54 @@ def _convert_output(output):
 
 
 class ColoTensor(torch.Tensor):
-    """ Data Structure for Tensor in Colossal-AI
-    1. It contains a torch.Tensor as an attribute.
-    2. It supports lazy init the tensor's payload.
-    3. It can hijack the torch functions which using ColoTensors as args to our customized functions.
-    4. It supports distributing the tensor's payload to the shards among processes. (TODO)
+    """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
+    Args:
+        data (torch.Tensor): a torch tensor used as the payload the colotensor.
+        spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate()).
+
+    The signature of the function has to be consistent with the __new__ except for the 1st arg.
+    The class should be initialized with a torch tensor in the following ways.
+    1. directly init.
+    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
+    >>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
+    >>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
+    >>>                             dims=[0],
+    >>>                             num_partitions=[world_size])
+    >>> tensor_spec = TensorSpec(shard_spec)
+    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+    2. use static method from_torch_tensor
+    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
     """
 
     def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
+        """__new__
+        The signature of the __new__ has to be consistent with the torch.Tensor.
+        Args:
+            data (torch.Tensor): a torch tensor used as the payload the colotensor.
+            spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate())
+        Returns:
+            ColoTensor: a ColoTensor wrappers the data.
+        """
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, data.requires_grad)
 
     def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
-        self._spec = copy(spec)
+        self._tensor_spec = copy(spec)
         self._type = TensorType.NONMODEL
         self._graph_node = None
 
     @property
     def spec(self) -> TensorSpec:
-        return self._spec
+        return self._tensor_spec
 
     def set_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
-        self.convert_to_dist_spec_(spec.dist_spec)
-        self._spec = spec
+        self._convert_to_dist_spec(spec.dist_spec)
+        self._tensor_spec = spec
 
     def has_spec(self) -> bool:
-        return self._spec.parallel_action is not None
+        return self._tensor_spec.parallel_action is not None
 
     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL
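The new class docstring above describes two ways to construct a ColoTensor: calling the constructor directly, or wrapping an existing tensor with the static from_torch_tensor method. The following is a minimal sketch of both paths with the docstring's missing closing parentheses restored; it assumes ColoTensor, TensorSpec and distspec are importable from colossalai.tensor as in this diff's own imports, and the sharded variant (commented out) additionally assumes a launched Colossal-AI context providing gpc, ParallelMode, world_size and a local shard tensor (local_shard is a hypothetical name).

import torch
from colossalai.tensor import ColoTensor, TensorSpec, distspec

# 1. Direct init: the payload is wrapped with a replicated spec (the default).
colo_t1 = ColoTensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))

# 2. Static constructor: wrap an existing torch.Tensor as a ColoTensor.
colo_t2 = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))

# In a sharded model, the tensor passed in is one shard of the global tensor,
# and the spec records how the shards are laid out across the process group.
# Requires a launched distributed context (gpc, ParallelMode, world_size).
# shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
#                             dims=[0],
#                             num_partitions=[world_size])
# colo_t3 = ColoTensor.from_torch_tensor(local_shard, TensorSpec(shard_spec))

Per the diff, __init__ only copies the spec and marks the tensor as non-model data; construction itself never moves or repartitions the payload.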
@@ -74,16 +95,16 @@ class ColoTensor(torch.Tensor):
     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL
 
-    def convert_to_dist_spec_(self, dist_spec: _DistSpec) -> None:
+    def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
         with DistSpecManager.no_grad():
             self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
-        self._spec.dist_spec = dist_spec
+        self._tensor_spec.dist_spec = dist_spec
 
     def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
-        spec = copy(self._spec)
-        spec.dist_spec = dist_spec
+        tensor_spec = copy(self._tensor_spec)
+        tensor_spec.dist_spec = dist_spec
         ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
-        return ColoTensor.from_torch_tensor(ret, spec)
+        return ColoTensor.from_torch_tensor(ret, tensor_spec)
 
     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
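After the renames in these hunks (convert_to_dist_spec_ becomes the private _convert_to_dist_spec, and _spec becomes _tensor_spec), the caller-facing API is set_spec for in-place re-distribution and convert_to_dist_spec for producing a new ColoTensor while leaving the original untouched. A hedged sketch of that flow follows; it assumes a colossalai.launch-initialized DATA parallel group and the import paths colossalai.core.global_context and colossalai.context.ParallelMode, none of which are shown in this diff.

import torch
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.tensor import ColoTensor, TensorSpec, distspec

# Start from a replicated ColoTensor (every rank holds the full payload).
t = ColoTensor.from_torch_tensor(torch.randn(4, 8))

world_size = gpc.get_world_size(ParallelMode.DATA)
shard = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
                       dims=[0],
                       num_partitions=[world_size])

# In place: repartitions t.data and updates its spec,
# delegating internally to the renamed _convert_to_dist_spec.
t.set_spec(TensorSpec(shard))

# Out of place: t keeps its sharded spec; a new ColoTensor with the
# replicated layout is returned.
t_replica = t.convert_to_dist_spec(distspec.replicate())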