[hotfix] the bug of numel() in ColoTensor (#845)

Author: Jiarui Fang
Date: 2022-04-24 12:32:10 +08:00
Committed by: GitHub
Parent: c1e8d2001e
Commit: ea0a2ed25f
2 changed files with 21 additions and 6 deletions
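Only the test file is shown in this excerpt; the second changed file presumably contains the numel() fix itself. As background for the new assertion in the diff below (lazy_t.numel() == 6 while the internal placeholder is still empty), here is a minimal, self-contained sketch of the failure mode. This is not ColossalAI's implementation; the class and attribute names below are hypothetical:

```python
import operator
from functools import reduce

import torch


class LazyWrapper:
    """Hypothetical stand-in for a lazily initialized tensor wrapper."""

    def __init__(self, *size, dtype=torch.float32, requires_grad=False):
        self._size = size
        self._dtype = dtype
        self._requires_grad = requires_grad
        self._torch_tensor = torch.empty(0)  # placeholder until first real access

    def torch_tensor(self):
        # Materialize the payload on first access.
        if self._torch_tensor.numel() == 0:
            self._torch_tensor = torch.empty(*self._size,
                                             dtype=self._dtype,
                                             requires_grad=self._requires_grad)
        return self._torch_tensor

    def numel(self):
        # Buggy variant: return self._torch_tensor.numel(), which is 0 before
        # the payload is materialized. Fixed variant: derive the element count
        # from the stored size instead.
        return reduce(operator.mul, self._size, 1)


lazy_t = LazyWrapper(2, 3)
assert lazy_t._torch_tensor.numel() == 0   # payload not created yet
assert lazy_t.numel() == 6                 # correct even before materialization
assert lazy_t.torch_tensor().numel() == 6  # matches once the payload exists
```

The new test assertion pins down exactly this property: numel() must agree with torch_tensor().numel() even when the underlying storage has not been allocated yet.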


@@ -3,6 +3,7 @@ import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy


def test_linear():
    in_dim = 4
    out_dim = 5

@@ -44,6 +45,7 @@ def test_linear():
    # torch.nn.init.uniform_(t)
    # print(t)


def test_element_wise():
    t_ref = torch.randn(3, 5)
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())

@@ -59,10 +61,12 @@ def test_no_wrap_op():
    assert torch.sum(t) == torch.sum(t_ref)
    assert torch.sum(input=t) == torch.sum(input=t_ref)


def test_lazy_init_tensor():
    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
    assert lazy_t._torch_tensor.numel() == 0
    assert lazy_t.torch_tensor().numel() == 6
    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()


def check_all():
    test_linear()

@@ -70,5 +74,6 @@ def check_all():
    test_no_wrap_op()
    test_lazy_init_tensor()


if __name__ == '__main__':
    check_all()
    test_lazy_init_tensor()
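
For reference, the behavior the updated test expects can be exercised as a standalone snippet using only the constructors and methods that appear in the diff above (a sketch, assuming colossalai is installed and ColoTensor behaves as the post-fix test asserts):

```python
import torch
from colossalai.tensor import ColoTensor

# A lazily initialized ColoTensor, as in test_lazy_init_tensor().
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)

print(lazy_t._torch_tensor.numel())   # 0: the internal placeholder is still empty
print(lazy_t.numel())                 # 6: the fixed numel() reflects the 2 x 3 shape
print(lazy_t.torch_tensor().numel())  # 6: accessing torch_tensor() materializes the payload
```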