From 126ba573a84eda626d88d516b16f6c18cdea0caa Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Mon, 25 Apr 2022 11:49:20 +0800
Subject: [PATCH] [Tensor] add layer norm Op (#852)

---
 colossalai/tensor/_ops/__init__.py     |  3 +-
 colossalai/tensor/_ops/element_wise.py |  8 ++++--
 colossalai/tensor/_ops/layernorm.py    | 38 ++++++++++++++++++++++++++
 colossalai/tensor/colo_tensor.py       |  4 +++
 tests/test_tensor/test_op.py           | 34 ++++++++++++++++++++---
 5 files changed, 79 insertions(+), 8 deletions(-)
 create mode 100644 colossalai/tensor/_ops/layernorm.py

diff --git a/colossalai/tensor/_ops/__init__.py b/colossalai/tensor/_ops/__init__.py
index 0fb96d9fa..d1b945dd2 100644
--- a/colossalai/tensor/_ops/__init__.py
+++ b/colossalai/tensor/_ops/__init__.py
@@ -1,3 +1,4 @@
 from .init import colo_uniform
 from .linear import colo_linear
-from .element_wise import colo_mean
\ No newline at end of file
+from .element_wise import colo_mean
+from .layernorm import colo_layernorm
\ No newline at end of file
diff --git a/colossalai/tensor/_ops/element_wise.py b/colossalai/tensor/_ops/element_wise.py
index 1843784e6..076e3463e 100644
--- a/colossalai/tensor/_ops/element_wise.py
+++ b/colossalai/tensor/_ops/element_wise.py
@@ -5,8 +5,10 @@ from colossalai.tensor import ColoTensor
 
 @colo_op_impl(torch.mean)
 def colo_mean(types, args=(), kwargs=None, pg=None):
-    stateful_tensor = args[0]
-    return torch.mean(stateful_tensor.torch_tensor())
+    input_t = args[0]
+    if isinstance(input_t, ColoTensor):
+        input_t = input_t.torch_tensor()
+    return ColoTensor.init_from_torch_tensor(torch.mean(input_t))
 
 
 def register_elementwise_op(op):
@@ -22,7 +24,7 @@ def register_elementwise_op(op):
         # Validate types
         if not isinstance(input_tensor, ColoTensor):
            raise TypeError("input needs to be a ColoTensor")
-        return op(input_tensor.torch_tensor())
+        return ColoTensor.init_from_torch_tensor(op(input_tensor.torch_tensor()))
 
 
 register_elementwise_op(torch.nn.functional.gelu)
diff --git a/colossalai/tensor/_ops/layernorm.py b/colossalai/tensor/_ops/layernorm.py
new file mode 100644
index 000000000..d616fd104
--- /dev/null
+++ b/colossalai/tensor/_ops/layernorm.py
@@ -0,0 +1,38 @@
+import torch
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor
+
+
+@colo_op_impl(torch.nn.functional.layer_norm)
+def colo_layernorm(types, args=(), kwargs=None, pg=None):
+    arg_num = len(args)
+    if arg_num > 0:
+        input_tensor = args[0]
+    if arg_num > 1:
+        normalized_shape = args[1]
+    if arg_num > 2:
+        weight = args[2]
+    if arg_num > 3:
+        bias = args[3]
+    if arg_num > 4:
+        eps = args[4]
+
+    if 'input' in kwargs:
+        input_tensor = kwargs['input']
+    if 'weight' in kwargs:
+        weight = kwargs['weight']
+    if 'bias' in kwargs:
+        bias = kwargs['bias']
+    if 'eps' in kwargs:
+        eps = kwargs['eps']
+
+    # unwrap ColoTensor arguments into plain torch.Tensors before calling the torch op
+    if isinstance(input_tensor, ColoTensor):
+        input_tensor = input_tensor.torch_tensor()
+    if isinstance(weight, ColoTensor):
+        weight = weight.torch_tensor()
+    if isinstance(bias, ColoTensor):
+        bias = bias.torch_tensor()
+
+    return ColoTensor.init_from_torch_tensor(
+        torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps))
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 908de7afb..206388b2a 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -8,6 +8,7 @@ from colossalai.context import ParallelMode
 from colossalai.nn.layer.utils import divide
 from colossalai.utils.cuda import get_current_device
 
+
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
@@ -145,3 +146,6 @@ class ColoTensor(object):
             kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
 
             return func(*args, **kwargs)
+
+    def backward(self, retain_graph: bool = False):
+        self._torch_tensor.backward(retain_graph=retain_graph)
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index fd9febf01..7156c536e 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -1,7 +1,32 @@
-from numpy import allclose, require
+from numpy import allclose
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
+from colossalai.utils import get_current_device
+
+
+def test_layernorm():
+    ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
+    ln_op_colo = deepcopy(ln_op)
+
+    input_t = torch.randn(3, 2, device=get_current_device())
+    input_t_colo = ColoTensor.init_from_torch_tensor(tensor=input_t.clone().detach())
+
+    # prepare colossalai LN
+    delattr(ln_op_colo, 'weight')
+    weight_clone = ln_op.weight.clone().detach()
+    weight_clone.requires_grad = True
+    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
+
+    output = ln_op(input_t)
+    output_colo = ln_op_colo(input_t_colo)
+
+    assert allclose(output_colo.torch_tensor().detach().cpu(), output.detach().cpu())
+
+    torch.mean(output).backward()
+    torch.mean(output_colo).backward()
+
+    assert allclose(ln_op.weight.grad.cpu(), ln_op_colo.weight.torch_tensor().grad.cpu())
 
 
 def test_linear():
@@ -50,8 +75,8 @@ def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
-    assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
-    assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
+    assert allclose(torch.nn.functional.gelu(t).torch_tensor(), torch.nn.functional.gelu(t_ref))
+    assert allclose(torch.nn.functional.relu(t).torch_tensor(), torch.nn.functional.relu(t_ref))
 
 
 # Test a function not wrapped by
@@ -76,4 +101,5 @@ def check_all():
 
 
 if __name__ == '__main__':
-    test_lazy_init_tensor()
+    # test_lazy_init_tensor()
+    test_layernorm()
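
For readers trying out the patch, the following is a minimal usage sketch (not part of the commit) of the dispatch path the new op enables. It assumes the patch is applied, that colossalai and torch are importable, and that calling torch.nn.functional.layer_norm on ColoTensor arguments reaches colo_layernorm through ColoTensor's __torch_function__ hook, which is the same path test_layernorm exercises indirectly via nn.LayerNorm; the shapes and eps value below are illustrative only.

import torch
from colossalai.tensor import ColoTensor

# Wrap plain torch tensors as ColoTensors; shapes are illustrative.
x = ColoTensor.init_from_torch_tensor(torch.randn(4, 8))
w = ColoTensor.init_from_torch_tensor(torch.ones(8))
b = ColoTensor.init_from_torch_tensor(torch.zeros(8))

# Calling the plain functional op on ColoTensor arguments should dispatch to
# colo_layernorm, which unwraps the ColoTensors and rewraps the torch result.
out = torch.nn.functional.layer_norm(x, (8,), w, b, 1e-5)
assert isinstance(out, ColoTensor)
assert out.torch_tensor().shape == (4, 8)

The result is rewrapped with ColoTensor.init_from_torch_tensor, matching the behavior of the updated colo_mean and element-wise ops in this patch.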