[tensor] reorganize files (#820)

Author: Jiarui Fang
Date: 2022-04-21 14:15:48 +08:00
Committed by: GitHub
Parent: ab962b9735
Commit: 0ce8924ceb
11 changed files with 71 additions and 76 deletions


@@ -1,10 +1,6 @@
 from numpy import allclose
 import torch
 from torch import nn
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
-# TODO(jiaruifang) auto import
-from colossalai.gemini.tensor._ops import *
-from colossalai.gemini.tensor.api import _STATEFUL_OPS
+from colossalai.tensor import ColoTensor
 from copy import deepcopy
@@ -18,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()
-    sharded_weight = StatefulTensorV2(fc_ref.weight)
-    sharded_bias = StatefulTensorV2(fc_ref.bias)
+    sharded_weight = ColoTensor(fc_ref.weight)
+    sharded_bias = ColoTensor(fc_ref.bias)
     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
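
The context lines above show the test detaching the module's registered nn.Parameter before attaching the wrapped tensor in its place. A minimal sketch of that replacement pattern, with the hypothetical helper name replace_parameter (not part of this commit):

import torch
from torch import nn

def replace_parameter(module: nn.Module, name: str, wrapped) -> None:
    # nn.Module.__delattr__ removes the entry from module._parameters,
    # so the following setattr stores `wrapped` as a plain attribute
    # instead of trying to re-register it as an nn.Parameter.
    delattr(module, name)
    setattr(module, name, wrapped)

# hypothetical usage mirroring the test: swap fc.weight for a wrapped copy
fc = nn.Linear(4, 4)
replace_parameter(fc, 'weight', fc.weight.detach().clone())

Without the delattr, assigning a non-Parameter tensor to fc.weight would raise a TypeError, because the name would still be registered in module._parameters.
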
@@ -45,15 +41,14 @@ def test_linear():
 # The test case failed
 # def test_uniform():
-#     t = StatefulTensorV2(torch.zeros(3, 5))
-#     # print(_STATEFUL_OPS)
+#     t = ColoTensor(torch.zeros(3, 5))
 #     torch.nn.init.uniform_(t)
 #     print(t)
 
 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = StatefulTensorV2(t_ref.clone())
+    t = ColoTensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
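
test_element_wise calls plain torch ops on the wrapper, so ColoTensor must intercept calls like torch.mean and torch.nn.functional.gelu and run them on the underlying data. A minimal sketch of one way such dispatch can be built, using PyTorch's __torch_function__ protocol; the actual ColoTensor dispatch (backed by the _ops registry dropped from the imports above) may be implemented differently:

import torch

class WrappedTensor:
    """Toy stand-in for a ColoTensor-style wrapper (illustrative only)."""

    def __init__(self, t: torch.Tensor):
        self._t = t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap any WrappedTensor arguments, run the plain torch op,
        # and hand back the raw result.
        def unwrap(x):
            return x._t if isinstance(x, cls) else x
        args = tuple(unwrap(a) for a in args)
        kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)

t = WrappedTensor(torch.randn(3, 5))
print(torch.mean(t))                # routed through __torch_function__
print(torch.nn.functional.relu(t))  # likewise for functional ops

A fuller implementation would presumably look func up in a registry first (as the removed _STATEFUL_OPS/_ops imports suggest) and fall back to plain execution only for unregistered ops.
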