diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 6f82f2c07..cfaac0331 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -1,16 +1,48 @@
 import torch
 from .op_wrapper import _COLOSSAL_OPS
+from typing import Tuple
 
 
 class ColoTensor(object):
+    """ Data Structure for Tensor in Colossal-AI
+    1. It contains a torch.Tensor as an attribute.
+    2. It supports lazy initialization of the tensor's payload.
+    3. It hijacks torch functions that take ColoTensors as arguments and dispatches them to our customized functions.
+    4. It supports distributing the tensor's payload as shards among processes. (TODO)
+    """
 
     def __new__(cls, *args, **kwargs):
         return super(ColoTensor, cls).__new__(cls)
 
-    def __init__(self, t: torch.Tensor) -> None:
-        self._torch_tensor = t
+    def __init__(
+            self,
+            *size: Tuple[int],
+            dtype=None,
+            requires_grad=False,
+            pin_memory=False,
+            torch_tensor=None,
+    ):
+        self._size = size
+        self._dtype = dtype
+        self._requires_grad = requires_grad
+        self._pin_memory = pin_memory
+        self._torch_tensor = torch_tensor
+
+    @staticmethod
+    def init_from_torch_tensor(tensor: torch.Tensor):
+        colo_t = ColoTensor(*tensor.size(),
+                            dtype=tensor.dtype,
+                            requires_grad=tensor.requires_grad,
+                            pin_memory=tensor.is_pinned(),
+                            torch_tensor=tensor)
+        return colo_t
 
     def torch_tensor(self) -> torch.Tensor:
+        if self._torch_tensor is None:
+            self._torch_tensor = torch.empty(*self._size,
+                                             dtype=self._dtype,
+                                             requires_grad=self._requires_grad,
+                                             pin_memory=self._pin_memory)
         return self._torch_tensor
 
     @classmethod
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 12555f741..3d1719eae 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -14,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()
 
-    sharded_weight = ColoTensor(fc_ref.weight)
-    sharded_bias = ColoTensor(fc_ref.bias)
+    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
 
     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
@@ -48,7 +48,7 @@ def test_linear():
 
 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
@@ -57,10 +57,16 @@
 # Test a function not wrapped by
 def test_no_wrap_op():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.sum(t) == torch.sum(t_ref)
 
 
+def test_lazy_init_tensor():
+    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor is None
+    assert lazy_t.torch_tensor().numel() == 6
+
+
 if __name__ == '__main__':
-    test_no_wrap_op()
+    test_lazy_init_tensor()
     # test_element_wise()
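
Usage sketch (not part of the patch): a minimal example of the two construction paths this change introduces, assuming the patch above is applied; variable names are illustrative only.

    import torch
    from colossalai.tensor import ColoTensor

    # Eager path: wrap an existing torch.Tensor; its payload is reused as-is.
    t_ref = torch.randn(3, 5)
    t = ColoTensor.init_from_torch_tensor(t_ref)
    assert t.torch_tensor() is t_ref

    # Lazy path: only metadata (size, dtype, requires_grad, pin_memory) is stored
    # at construction time; the payload is allocated with torch.empty on the
    # first torch_tensor() call.
    lazy_t = ColoTensor(2, 3, dtype=torch.float32)
    assert lazy_t._torch_tensor is None
    payload = lazy_t.torch_tensor()
    assert payload.shape == (2, 3) and payload.dtype == torch.float32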