diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index febbb5e94..e309c413f 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -2,7 +2,7 @@
 from colossalai.context import parallel_mode
 from .op_wrapper import _COLOSSAL_OPS
 import torch
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Callable
 from numpy import product
 from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
@@ -152,26 +152,28 @@ class ColoTensor(object):
                 kwargs = {}
 
             kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
-            return ColoTensor.init_from_torch_tensor(func(*args, **kwargs))
+            return cls._filter_outputs_with_colo(func(*args, **kwargs))
 
     def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
         self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
 
-    ## TODO(fjr) we reduce redundency of the following code
-    def __add__(self, o) -> "ColoTensor":
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
+    def __getattr__(self, name):
+        def replace_tensor_with_colo(func):
+            def execute_func(*args, **kwargs):
+                return self._filter_outputs_with_colo(func(*args, **kwargs))
+            return execute_func
 
-    def __truediv__(self, o) -> "ColoTensor":
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
+        attr = getattr(self._torch_tensor, name)
+        if isinstance(attr, Callable):
+            return replace_tensor_with_colo(attr)
+        else:
+            return attr
 
-    def view(self, *args: int) -> "ColoTensor":
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor().view(*args))
-
-    def permute(self, *args) -> "ColoTensor":
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor().permute(*args))
-
-    def transpose(self, *args) -> "ColoTensor":
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor().transpose(*args))
-
-    def contiguous(self):
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor().contiguous())
+    @classmethod
+    def _filter_outputs_with_colo(cls, outputs):
+        if outputs is None:    # no return value
+            return None
+        elif type(outputs) is not tuple:    # single return value
+            return ColoTensor.init_from_torch_tensor(outputs) if type(outputs) is torch.Tensor else outputs
+        else:    # multiple return values
+            return tuple([ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output for output in outputs])
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 4babb73cd..8482052b0 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -86,12 +86,32 @@ def test_no_wrap_op():
     assert torch.sum(t) == torch.sum(t_ref)
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
 
+def test_wrapped_tensor_func():
+    t_ref = torch.randn(4, 5)
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
+
+    # non-callable attribute
+    assert t.is_cuda == t_ref.is_cuda
+
+    # TODO: add a case for a tensor function that returns None once one is identified.
+
+    # returns a single torch.Tensor
+    t_abs = t.abs()
+    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs.torch_tensor(), t_ref.abs())
+
+    # returns a single non-torch.Tensor value
+    assert t.dim() == t_ref.dim()
+
+    # returns more than one torch.Tensor
+    t_split1, t_split2 = t.split(2)
+    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)
+
 def check_all():
     test_linear()
     test_element_wise()
     test_no_wrap_op()
-
+    test_wrapped_tensor_func()
 
 if __name__ == '__main__':
     check_all()
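
Reviewer note: the delegation pattern this patch introduces can be illustrated with a minimal standalone sketch using plain PyTorch. The MiniWrapper class below is a hypothetical stand-in for illustration only, not part of this patch and not ColoTensor itself: attribute lookups that miss on the wrapper fall through __getattr__ to the underlying torch.Tensor, callable attributes are wrapped so that torch.Tensor outputs are re-wrapped, and non-callable attributes or non-tensor results pass through unchanged.

# Minimal sketch of the delegation pattern added above.
# "MiniWrapper" is a hypothetical stand-in, not the ColoTensor implementation.
from typing import Callable

import torch


class MiniWrapper:

    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor  # set here, so __getattr__ never fires for it

    @classmethod
    def _rewrap(cls, outputs):
        # Mirrors _filter_outputs_with_colo: None passes through, a single Tensor
        # is wrapped, tuples are wrapped element-wise, anything else is returned as-is.
        if outputs is None:
            return None
        if isinstance(outputs, tuple):
            return tuple(cls(o) if isinstance(o, torch.Tensor) else o for o in outputs)
        return cls(outputs) if isinstance(outputs, torch.Tensor) else outputs

    def __getattr__(self, name):
        # Only reached for attributes the wrapper does not define itself.
        attr = getattr(self._tensor, name)
        if isinstance(attr, Callable):
            def forward(*args, **kwargs):
                # Run the underlying tensor method, then re-wrap Tensor outputs.
                return self._rewrap(attr(*args, **kwargs))
            return forward
        return attr  # non-callable attributes (.is_cuda, .shape, ...) pass through


if __name__ == '__main__':
    w = MiniWrapper(torch.randn(4, 5))
    assert isinstance(w.abs(), MiniWrapper)      # Tensor result is re-wrapped
    assert w.dim() == 2                          # non-Tensor result is untouched
    a, b = w.split(2)                            # tuple of Tensors -> tuple of wrappers
    assert isinstance(a, MiniWrapper) and isinstance(b, MiniWrapper)
    assert w.is_cuda is False                    # non-callable attribute delegated

Re-wrapping at the call boundary keeps chained calls such as w.abs().t() inside the wrapper type without enumerating every tensor method by hand, which is exactly the redundancy the removed __add__/__truediv__/view/permute/transpose/contiguous methods carried.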