[tensor] wrap function in the torch_tensor to ColoTensor (#881)
@@ -86,12 +86,32 @@ def test_no_wrap_op():
    assert torch.sum(t) == torch.sum(t_ref)
    assert torch.sum(input=t) == torch.sum(input=t_ref)


def test_wrapped_tensor_func():
    t_ref = torch.randn(4, 5)
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())

    # non-function attribute: delegated straight to the underlying tensor
    assert t.is_cuda == t_ref.is_cuda

    # TODO: find a tensor function that returns None so that case can be tested.

    # function returning one torch.Tensor: the result is re-wrapped as a ColoTensor
    t_abs = t.abs()
    assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs.torch_tensor(), t_ref.abs())

    # function returning one non-torch.Tensor: the result passes through unwrapped
    assert t.dim() == t_ref.dim()

    # function returning more than one torch.Tensor: each result is re-wrapped
    t_split1, t_split2 = t.split(2)
    assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor)


def check_all():
    test_linear()
    test_element_wise()
    test_no_wrap_op()
    test_wrapped_tensor_func()


if __name__ == '__main__':
    check_all()
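For context, here is a minimal sketch of the kind of wrapping behavior this test asserts, built on plain __getattr__ delegation. It is hypothetical, not ColossalAI's actual implementation: ColoTensorSketch and _wrap are invented names, and the real ColoTensor may intercept functions differently. The sketch only shows why method results come back wrapped while non-function attributes and non-tensor return values pass through.

import torch


class ColoTensorSketch:
    """Hypothetical stand-in for ColoTensor: holds a torch.Tensor and
    re-wraps tensor results of delegated method calls."""

    def __init__(self, t: torch.Tensor):
        self._torch_tensor = t

    @classmethod
    def init_from_torch_tensor(cls, t: torch.Tensor):
        return cls(t)

    def torch_tensor(self) -> torch.Tensor:
        return self._torch_tensor

    def __getattr__(self, name):
        # Called only for attributes not defined on this class, so
        # everything else is looked up on the underlying torch.Tensor.
        attr = getattr(self._torch_tensor, name)
        if callable(attr):
            def wrapped(*args, **kwargs):
                return _wrap(attr(*args, **kwargs))
            return wrapped
        # non-function attributes (e.g. is_cuda) pass through unchanged
        return attr


def _wrap(result):
    # Re-wrap torch.Tensor results; recurse into tuples (e.g. split());
    # leave non-tensor results (e.g. the int from dim()) untouched.
    if isinstance(result, torch.Tensor):
        return ColoTensorSketch(result)
    if isinstance(result, tuple):
        return tuple(_wrap(r) for r in result)
    return result

With this sketch, t.abs() returns a ColoTensorSketch, t.dim() returns a plain int, and t.split(2) returns a tuple of ColoTensorSketch instances, which mirrors the three assert groups in test_wrapped_tensor_func above.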