mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[tensor] reorganize files (#820)
This commit is contained in:
@@ -1,10 +1,6 @@
|
||||
from numpy import allclose
|
||||
import torch
|
||||
from torch import nn
|
||||
from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
|
||||
# TODO(jiaruifang) auto import
|
||||
from colossalai.gemini.tensor._ops import *
|
||||
from colossalai.gemini.tensor.api import _STATEFUL_OPS
|
||||
from colossalai.tensor import ColoTensor
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
@@ -18,8 +14,8 @@ def test_linear():
|
||||
input_ref = torch.randn(1, in_dim)
|
||||
input_tensor = input_ref.clone()
|
||||
|
||||
sharded_weight = StatefulTensorV2(fc_ref.weight)
|
||||
sharded_bias = StatefulTensorV2(fc_ref.bias)
|
||||
sharded_weight = ColoTensor(fc_ref.weight)
|
||||
sharded_bias = ColoTensor(fc_ref.bias)
|
||||
|
||||
# replace the torch nn.Parameters with ShardedTensor
|
||||
delattr(fc, 'weight')
|
||||
@@ -45,15 +41,14 @@ def test_linear():
|
||||
|
||||
# The test case failed
|
||||
# def test_uniform():
|
||||
# t = StatefulTensorV2(torch.zeros(3, 5))
|
||||
# # print(_STATEFUL_OPS)
|
||||
# t = ColoTensor(torch.zeros(3, 5))
|
||||
# torch.nn.init.uniform_(t)
|
||||
# print(t)
|
||||
|
||||
|
||||
def test_element_wise():
|
||||
t_ref = torch.randn(3, 5)
|
||||
t = StatefulTensorV2(t_ref.clone())
|
||||
t = ColoTensor(t_ref.clone())
|
||||
assert torch.mean(t) == torch.mean(t_ref)
|
||||
assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
|
||||
assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
|
Reference in New Issue
Block a user