Mirror of https://github.com/hpcaitech/ColossalAI.git
[tensor] lazy init (#823)
@@ -1,4 +1,4 @@
-from numpy import allclose
+from numpy import allclose, require
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
@@ -14,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()
 
-    sharded_weight = ColoTensor(fc_ref.weight)
-    sharded_bias = ColoTensor(fc_ref.bias)
+    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
 
     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
@@ -48,7 +48,7 @@ def test_linear():
 
 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
@@ -57,10 +57,16 @@ def test_element_wise():
 # Test a function not wrapped by
 def test_no_wrap_op():
     t_ref = torch.randn(3, 5)
-    t = ColoTensor(t_ref.clone())
+    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.sum(t) == torch.sum(t_ref)
 
 
+def test_lazy_init_tensor():
+    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor == None
+    assert lazy_t.torch_tensor().numel() == 6
+
+
 if __name__ == '__main__':
-    test_no_wrap_op()
+    test_lazy_init_tensor()
     # test_element_wise()
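As a quick illustration of the two construction paths exercised by this commit, here is a minimal sketch, assuming only the ColoTensor behaviour that the tests themselves assert (init_from_torch_tensor wraps an existing tensor, and the shape-based constructor leaves _torch_tensor unset until torch_tensor() is called):

import torch
from colossalai.tensor import ColoTensor

# Eager path: wrap an existing torch.Tensor
# (the new name for what the old ColoTensor(tensor) call did).
data = torch.randn(2, 3)
eager_t = ColoTensor.init_from_torch_tensor(data)
assert torch.sum(eager_t) == torch.sum(data)

# Lazy path: only shape, dtype and requires_grad are recorded at construction;
# per the new test above, _torch_tensor stays None until torch_tensor() is called.
lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor is None
assert lazy_t.torch_tensor().numel() == 6   # 2 * 3 elements materialized on demand

Presumably the point of the lazy path is to let a tensor be described by its metadata before any storage is allocated, though the commit itself only asserts the deferred-materialization behaviour shown here.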