[Tensor] add layer norm Op (#852)

This commit is contained in:
Jiarui Fang
2022-04-25 11:49:20 +08:00
committed by GitHub
parent a82da26f7e
commit 126ba573a8
5 changed files with 79 additions and 8 deletions

View File

@@ -1,7 +1,32 @@
from numpy import allclose, require
from numpy import allclose
import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy
from colossalai.utils import get_current_device
def test_layernorm():
ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
ln_op_colo = deepcopy(ln_op)
input_t = torch.randn(3, 2, device=get_current_device())
input_t_colo = ColoTensor.init_from_torch_tensor(tensor=input_t.clone().detach())
# prepare colossalai LN
delattr(ln_op_colo, 'weight')
weight_clone = ln_op.weight.clone().detach()
weight_clone.requires_grad = True
setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
output = ln_op(input_t)
output_colo = ln_op_colo(input_t_colo)
assert allclose(output_colo.torch_tensor().detach().cpu(), output.detach().cpu())
torch.mean(output).backward()
torch.mean(output_colo).backward()
assert allclose(ln_op.weight.grad.cpu(), ln_op_colo.weight.torch_tensor().grad.cpu())
def test_linear():
@@ -50,8 +75,8 @@ def test_element_wise():
t_ref = torch.randn(3, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
assert torch.mean(t) == torch.mean(t_ref)
assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
assert allclose(torch.nn.functional.gelu(t).torch_tensor(), torch.nn.functional.gelu(t_ref))
assert allclose(torch.nn.functional.relu(t).torch_tensor(), torch.nn.functional.relu(t_ref))
# Test a function not wrapped by
@@ -76,4 +101,5 @@ def check_all():
if __name__ == '__main__':
test_lazy_init_tensor()
# test_lazy_init_ptensor()
test_layernorm()