From 72cdc068755820481549f7f0b7ab5a1c652aee22 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Wed, 27 Apr 2022 10:57:49 +0800
Subject: [PATCH] [Tensor] make ColoTensor more robust for getattr (#886)

* [Tensor] make ColoTensor more robust for getattr

* polish

* polish
---
 colossalai/tensor/_ops/element_wise.py |  3 ++-
 colossalai/tensor/colo_tensor.py       | 31 +++++++++++++++++++++-----
 tests/test_tensor/test_op.py           | 22 +-----------------
 tests/test_tensor/test_tensor.py       | 30 +++++++++++++++++++++++++
 4 files changed, 58 insertions(+), 28 deletions(-)

diff --git a/colossalai/tensor/_ops/element_wise.py b/colossalai/tensor/_ops/element_wise.py
index 98f449188..e39b2a5b5 100644
--- a/colossalai/tensor/_ops/element_wise.py
+++ b/colossalai/tensor/_ops/element_wise.py
@@ -12,7 +12,8 @@ def colo_mean(types, args=(), kwargs=None, pg=None):
         a = a.torch_tensor()
     elif isinstance(b, ColoTensor):
         b = b.torch_tensor()
-
+    if kwargs is None:
+        kwargs = {}
     return torch.allclose(a, b, **kwargs)
 
 
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index e309c413f..e22fa5850 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -152,18 +152,34 @@ class ColoTensor(object):
             kwargs = {}
 
         kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
-        return cls._filter_outputs_with_colo(func(*args,**kwargs))
+        return cls._filter_outputs_with_colo(func(*args, **kwargs))
 
     def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
         self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
 
+    def __add__(self, o) -> "ColoTensor":
+        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
+
+    def __truediv__(self, o) -> "ColoTensor":
+        return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
+
     def __getattr__(self, name):
+
         def replace_tensor_with_colo(func):
+
             def execute_func(*args, **kwargs):
-                return self._filter_outputs_with_colo(func(*args,**kwargs))
+                # transform the ColoTensor args to torch Tensor.
+                args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
+                if kwargs is None:
+                    kwargs = {}
+                kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
+                return self._filter_outputs_with_colo(func(*args, **kwargs))
+
             return execute_func
+
         assert hasattr(self._torch_tensor, name), f"torch.Tensor has not attribute named as {name}. So is ColoTensor"
         attr = getattr(self._torch_tensor, name)
+
         if isinstance(attr, Callable):
             return replace_tensor_with_colo(attr)
         else:
@@ -171,9 +187,12 @@ class ColoTensor(object):
 
     @classmethod
     def _filter_outputs_with_colo(cls, outputs):
-        if outputs is None: # return None
+        if outputs is None:    # return None
             return None
-        elif type(outputs) is not tuple: # num of return val = 1
+        elif type(outputs) is not tuple:    # num of return val = 1
             return ColoTensor.init_from_torch_tensor(outputs) if type(outputs) is torch.Tensor else outputs
-        else: # num of return val > 1
-            return tuple([ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output for output in outputs])
+        else:    # num of return val > 1
+            return tuple([
+                ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
+                for output in outputs
+            ])
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 8482052b0..4babb73cd 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -86,32 +86,12 @@ def test_no_wrap_op():
     assert torch.sum(t) == torch.sum(t_ref)
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
-def test_wrapped_tensor_func():
-    t_ref = torch.randn(4, 5)
-    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
-
-    # non-func attr
-    assert t.is_cuda == t_ref.is_cuda
-
-    # TODO I don't find out a tensor function which returns None.
-
-    # return 1 torch.Tensor
-    t_abs = t.abs()
-    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs.torch_tensor(), t_ref.abs())
-
-    # return 1 non-torch.Tensor
-    assert t.dim() == t_ref.dim()
-
-    # return >1 torch.Tensor
-    t_split1, t_split2 = t.split(2)
-    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)
-
 def check_all():
     test_linear()
     test_element_wise()
     test_no_wrap_op()
-    test_wrapped_tensor_func()
+
 
 
 if __name__ == '__main__':
     check_all()
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index c7216a9f8..f75eadf84 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -13,3 +13,33 @@ def test_lazy_init_tensor():
     lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
     assert lazy_t._torch_tensor.numel() == 0
     assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
+
+
+def test_wrapped_tensor_func():
+    t_ref = torch.randn(4, 5)
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
+
+    # non-func attr
+    assert t.is_cuda == t_ref.is_cuda
+
+    # TODO I don't find out a tensor function which returns None.
+
+    # return 1 torch.Tensor
+    t_abs = t.abs()
+    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs.torch_tensor(), t_ref.abs())
+
+    # return 1 non-torch.Tensor
+    assert t.dim() == t_ref.dim()
+
+    # return >1 torch.Tensor
+    t_split1, t_split2 = t.split(2)
+    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)
+
+
+def test_operand():
+    t_ref = torch.randn(4, 5)
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
+
+    t_ref_res = t_ref + t_ref
+    t_res = t + t
+    assert torch.allclose(t_ref_res, t_res)
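
Note (illustrative, not part of the patch): a minimal usage sketch of the behaviour the new tests exercise. It assumes ColoTensor is importable from colossalai.tensor, as the test files in tests/test_tensor do. Methods forwarded through __getattr__ now unwrap ColoTensor arguments into torch tensors and wrap tensor results back into ColoTensor, non-tensor results pass through unchanged, and the new __add__/__truediv__ operators keep results as ColoTensor.

    import torch
    from colossalai.tensor import ColoTensor    # assumed export path, mirroring the tests above

    t_ref = torch.randn(4, 5)
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())

    # A forwarded method returning one tensor is wrapped back into a ColoTensor.
    t_abs = t.abs()
    assert isinstance(t_abs, ColoTensor)

    # Non-tensor results pass through unchanged.
    assert t.dim() == t_ref.dim()

    # A method returning several tensors wraps each one.
    t_split1, t_split2 = t.split(2)
    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)

    # The new operators keep the result as a ColoTensor.
    t_sum = t + t       # __add__ unwraps the right-hand ColoTensor operand
    t_half = t / 2.0    # __truediv__ divides the wrapped tensor by the operand
    assert isinstance(t_sum, ColoTensor) and isinstance(t_half, ColoTensor)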