Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 10:06:44 +00:00)
[Tensor] add layer norm Op (#852)
@@ -1,7 +1,32 @@
-from numpy import allclose, require
+from numpy import allclose
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
+from colossalai.utils import get_current_device


+def test_layernorm():
+    ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
+    ln_op_colo = deepcopy(ln_op)
+
+    input_t = torch.randn(3, 2, device=get_current_device())
+    input_t_colo = ColoTensor.init_from_torch_tensor(tensor=input_t.clone().detach())
+
+    # prepare colossalai LN
+    delattr(ln_op_colo, 'weight')
+    weight_clone = ln_op.weight.clone().detach()
+    weight_clone.requires_grad = True
+    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
+
+    output = ln_op(input_t)
+    output_colo = ln_op_colo(input_t_colo)
+
+    assert allclose(output_colo.torch_tensor().detach().cpu(), output.detach().cpu())
+
+    torch.mean(output).backward()
+    torch.mean(output_colo).backward()
+
+    assert allclose(ln_op.weight.grad.cpu(), ln_op_colo.weight.torch_tensor().grad.cpu())
+
+
 def test_linear():
@@ -50,8 +75,8 @@ def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
-    assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
-    assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
+    assert allclose(torch.nn.functional.gelu(t).torch_tensor(), torch.nn.functional.gelu(t_ref))
+    assert allclose(torch.nn.functional.relu(t).torch_tensor(), torch.nn.functional.relu(t_ref))


 # Test a function not wrapped by
@@ -76,4 +101,5 @@ def check_all():

 if __name__ == '__main__':
     test_lazy_init_tensor()
     # test_lazy_init_ptensor()
+    test_layernorm()
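For readers skimming the diff: the new test relies on a simple parameter-swap pattern, namely replacing a torch.nn module's weight with a ColoTensor built from a detached clone, then comparing the forward output and the weight gradient against the unmodified module. Below is a minimal illustrative sketch of that pattern, not part of the commit; it assumes only the two ColoTensor calls visible in this diff (init_from_torch_tensor and torch_tensor), and the helper name colo_swap_param is hypothetical.

# Illustrative sketch only, not part of commit #852: the parameter-swap
# pattern used by test_layernorm above, factored into a reusable helper.
import torch
from copy import deepcopy
from colossalai.tensor import ColoTensor


def colo_swap_param(module: torch.nn.Module, name: str = 'weight') -> torch.nn.Module:
    # Deep-copy the module so the original keeps its plain torch parameters.
    colo_module = deepcopy(module)
    # Clone the parameter, detach it from the original graph, and re-enable grad.
    clone = getattr(module, name).clone().detach()
    clone.requires_grad = True
    # Replace the torch parameter with a ColoTensor wrapping the clone.
    delattr(colo_module, name)
    setattr(colo_module, name, ColoTensor.init_from_torch_tensor(tensor=clone))
    return colo_module

With such a helper, the test above reduces to building ln_op, calling colo_swap_param(ln_op), and asserting allclose on the two outputs and on ln_op.weight.grad versus the ColoTensor weight's torch_tensor().grad.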