mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[gemini] a new tensor structure (#818)
* Revert "[zero] add ZeroTensorShardStrategy (#793)"
This reverts commit 88759e289e
.
* [gemini] set cpu memory capacity
* [log] local throughput collecting
* polish
* polish
* polish
* polish code
* polish
* polish code
* add a new tensor structure and override linear for it
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
This commit is contained in:
64
tests/test_gemini/test_tensor.py
Normal file
64
tests/test_gemini/test_tensor.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from numpy import allclose
|
||||
import torch
|
||||
from torch import nn
|
||||
from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
|
||||
# TODO(jiaruifang) auto import
|
||||
from colossalai.gemini.tensor._ops import *
|
||||
from colossalai.gemini.tensor.api import _STATEFUL_OPS
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
def test_linear():
|
||||
in_dim = 4
|
||||
out_dim = 5
|
||||
|
||||
fc = torch.nn.Linear(in_dim, out_dim, bias=True)
|
||||
fc_ref = deepcopy(fc)
|
||||
|
||||
input_ref = torch.randn(1, in_dim)
|
||||
input_tensor = input_ref.clone()
|
||||
|
||||
sharded_weight = StatefulTensorV2(fc_ref.weight)
|
||||
sharded_bias = StatefulTensorV2(fc_ref.bias)
|
||||
|
||||
# replace the torch nn.Parameters with ShardedTensor
|
||||
delattr(fc, 'weight')
|
||||
setattr(fc, 'weight', sharded_weight)
|
||||
delattr(fc, 'bias')
|
||||
setattr(fc, 'bias', sharded_bias)
|
||||
|
||||
fc.weight.requires_grad = True
|
||||
fc.bias.requires_grad = True
|
||||
|
||||
# torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
|
||||
out = fc(input_tensor)
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
|
||||
out_ref = fc_ref(input_ref)
|
||||
loss_ref = out_ref.sum()
|
||||
loss_ref.backward()
|
||||
|
||||
assert (loss_ref == loss)
|
||||
assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
|
||||
|
||||
|
||||
# The test case failed
|
||||
# def test_uniform():
|
||||
# t = StatefulTensorV2(torch.zeros(3, 5))
|
||||
# # print(_STATEFUL_OPS)
|
||||
# torch.nn.init.uniform_(t)
|
||||
# print(t)
|
||||
|
||||
|
||||
def test_element_wise():
|
||||
t_ref = torch.randn(3, 5)
|
||||
t = StatefulTensorV2(t_ref.clone())
|
||||
assert torch.mean(t) == torch.mean(t_ref)
|
||||
assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
|
||||
assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_linear()
|
||||
# test_element_wise()
|
Reference in New Issue
Block a user