[Tensor] make ColoTensor more robust for getattr (#886)

* [Tensor] make ColoTensor more robust for getattr

* polish

* polish
This commit is contained in:
Jiarui Fang
2022-04-27 10:57:49 +08:00
committed by GitHub
parent 9bc5a77c31
commit 72cdc06875
4 changed files with 58 additions and 28 deletions

View File

@@ -86,32 +86,12 @@ def test_no_wrap_op():
assert torch.sum(t) == torch.sum(t_ref)
assert torch.sum(input=t) == torch.sum(input=t_ref)
def test_wrapped_tensor_func():
t_ref = torch.randn(4, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
# non-func attr
assert t.is_cuda == t_ref.is_cuda
# TODO I don't find out a tensor function which returns None.
# return 1 torch.Tensor
t_abs = t.abs()
assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs.torch_tensor(), t_ref.abs())
# return 1 non-torch.Tensor
assert t.dim() == t_ref.dim()
# return >1 torch.Tensor
t_split1, t_split2 = t.split(2)
assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)
def check_all():
test_linear()
test_element_wise()
test_no_wrap_op()
test_wrapped_tensor_func()
if __name__ == '__main__':
check_all()

View File

@@ -13,3 +13,33 @@ def test_lazy_init_tensor():
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor.numel() == 0
assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
def test_wrapped_tensor_func():
t_ref = torch.randn(4, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
# non-func attr
assert t.is_cuda == t_ref.is_cuda
# TODO I don't find out a tensor function which returns None.
# return 1 torch.Tensor
t_abs = t.abs()
assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs.torch_tensor(), t_ref.abs())
# return 1 non-torch.Tensor
assert t.dim() == t_ref.dim()
# return >1 torch.Tensor
t_split1, t_split2 = t.split(2)
assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)
def test_operand():
t_ref = torch.randn(4, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
t_ref_res = t_ref + t_ref
t_res = t + t
assert torch.allclose(t_ref_res, t_res)