Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 09:38:05 +00:00)
[hotfix] the bug of numel() in ColoTensor (#845)
@@ -3,6 +3,7 @@ import torch
 from colossalai.tensor import ColoTensor
+from copy import deepcopy
 
 
 def test_linear():
     in_dim = 4
     out_dim = 5
@@ -44,6 +45,7 @@ def test_linear():
     # torch.nn.init.uniform_(t)
     # print(t)
 
+
 def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
@@ -59,10 +61,12 @@ def test_no_wrap_op():
     assert torch.sum(t) == torch.sum(t_ref)
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
 
 def test_lazy_init_tensor():
     lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
     assert lazy_t._torch_tensor.numel() == 0
+    assert lazy_t.torch_tensor().numel() == 6
+    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
 
 
 def check_all():
     test_linear()
@@ -70,5 +74,6 @@ def check_all():
     test_no_wrap_op()
+    test_lazy_init_tensor()
 
 
 if __name__ == '__main__':
     check_all()
     test_lazy_init_tensor()
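For context, here is a minimal sketch of the kind of fix the commit title describes: when a ColoTensor is lazily initialized, its backing torch tensor has not been allocated yet, so a numel() that simply forwards to the backing tensor returns 0; deriving the element count from the recorded shape fixes this. The LazyShapeTensor class below and its _size / _torch_tensor attributes are illustrative stand-ins, not the actual ColossalAI implementation.

import operator
from functools import reduce

import torch


class LazyShapeTensor:
    """Illustrative stand-in for ColoTensor's lazy-init behavior (hypothetical, not the real API)."""

    def __init__(self, *size, dtype=torch.float32, requires_grad=False):
        self._size = size
        self._dtype = dtype
        self._requires_grad = requires_grad
        # Lazy init: no real storage yet, so the backing tensor holds 0 elements.
        self._torch_tensor = torch.empty(0, dtype=dtype, requires_grad=requires_grad)

    def numel(self):
        # The buggy version would forward to self._torch_tensor.numel(),
        # which is 0 before materialization; compute from the shape instead.
        return reduce(operator.mul, self._size, 1)

    def torch_tensor(self):
        # Materialize the backing tensor on first real access.
        if self._torch_tensor.numel() == 0:
            self._torch_tensor = torch.empty(
                *self._size, dtype=self._dtype, requires_grad=self._requires_grad)
        return self._torch_tensor


lazy_t = LazyShapeTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor.numel() == 0    # nothing materialized yet
assert lazy_t.torch_tensor().numel() == 6   # materializes, then counts
assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()  # shape-derived count agrees

This mirrors the assertions test_lazy_init_tensor gains in the diff above: the element count must be available, and correct, whether or not the underlying tensor has been materialized.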