[tensor] lazy init (#823)

Author: Jiarui Fang (committed by GitHub)
Date: 2022-04-21 15:40:23 +08:00
Parent: 68dcd51d41
Commit: 2ecc3d7a55
2 changed files with 46 additions and 8 deletions


@@ -1,4 +1,4 @@
-from numpy import allclose
+from numpy import allclose, require
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
@@ -14,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()
 
-    sharded_weight = ColoTensor(fc_ref.weight)
-    sharded_bias = ColoTensor(fc_ref.bias)
+    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
 
     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
@@ -48,7 +48,7 @@ def test_linear():
 
 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
@@ -57,10 +57,16 @@ def test_element_wise():
 
 # Test a function not wrapped by
 def test_no_wrap_op():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.sum(t) == torch.sum(t_ref)
 
+def test_lazy_init_tensor():
+    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor == None
+    assert lazy_t.torch_tensor().numel() == 6
+
 if __name__ == '__main__':
     test_no_wrap_op()
+    test_lazy_init_tensor()
     # test_element_wise()
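
For context, below is a minimal sketch of the lazy-init pattern the new test exercises. It is an illustration under assumed semantics, not the actual colossalai implementation: the constructor only records shape/dtype/requires_grad metadata and leaves _torch_tensor unset, torch_tensor() allocates the payload on first access, and init_from_torch_tensor() is the eager path that wraps an existing torch.Tensor.

import torch

class ColoTensor:
    """Sketch of a tensor wrapper whose payload is allocated lazily."""

    def __init__(self, *size, dtype=None, requires_grad=False):
        # Accept both ColoTensor(2, 3) and ColoTensor((2, 3)), as in the test.
        if len(size) == 1 and isinstance(size[0], (tuple, list, torch.Size)):
            size = tuple(size[0])
        self._size = size
        self._dtype = dtype
        self._requires_grad = requires_grad
        self._torch_tensor = None  # no storage yet; filled in on demand

    @classmethod
    def init_from_torch_tensor(cls, tensor):
        # Eager path: wrap an existing tensor without re-allocating.
        colo_t = cls(*tensor.size(), dtype=tensor.dtype,
                     requires_grad=tensor.requires_grad)
        colo_t._torch_tensor = tensor
        return colo_t

    def torch_tensor(self):
        # Lazy path: materialize the payload on first access.
        if self._torch_tensor is None:
            self._torch_tensor = torch.empty(*self._size, dtype=self._dtype,
                                             requires_grad=self._requires_grad)
        return self._torch_tensor

Under this sketch, ColoTensor((2, 3), dtype=torch.float32, requires_grad=True) holds no storage until torch_tensor() is called, which then allocates a 2x3 tensor with numel() == 6, matching the assertions in test_lazy_init_tensor above.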