[Tensor] add layer norm Op (#852)

This commit registers a layer-norm op for ColoTensor, makes colo_mean and the
registered element-wise ops wrap their results back into ColoTensor, and adds a
backward() pass-through to ColoTensor.
colossalai/tensor/_ops/__init__.py
@@ -1,3 +1,4 @@
 from .init import colo_uniform
 from .linear import colo_linear
 from .element_wise import colo_mean
+from .layernorm import colo_layernorm
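For context, colo_op_impl is a registration decorator: each colo_* function is
keyed by the torch op it replaces and looked up when a ColoTensor reaches
__torch_function__. A rough sketch of the idea (hypothetical names, not the
actual op_wrapper code):

    _COLOSSAL_OPS = {}  # maps a torch op to its ColoTensor implementation

    def colo_op_impl(torch_op):
        def decorator(func):
            _COLOSSAL_OPS[torch_op] = func  # register the override
            return func
        return decorator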
colossalai/tensor/_ops/element_wise.py
@@ -5,8 +5,10 @@ from colossalai.tensor import ColoTensor

 @colo_op_impl(torch.mean)
 def colo_mean(types, args=(), kwargs=None, pg=None):
-    stateful_tensor = args[0]
-    return torch.mean(stateful_tensor.torch_tensor())
+    input_t = args[0]
+    if isinstance(input_t, ColoTensor):
+        input_t = input_t.torch_tensor()
+    return ColoTensor.init_from_torch_tensor(torch.mean(input_t))


 def register_elementwise_op(op):
@@ -22,7 +24,7 @@ def register_elementwise_op(op):
         # Validate types
         if not isinstance(input_tensor, ColoTensor):
             raise TypeError("input needs to be a ColoTensor")
-        return op(input_tensor.torch_tensor())
+        return ColoTensor.init_from_torch_tensor(op(input_tensor.torch_tensor()))


 register_elementwise_op(torch.nn.functional.gelu)
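After this change, mean and the registered element-wise ops return a ColoTensor
instead of a bare torch.Tensor. A minimal usage sketch (not part of the patch),
assuming ColoTensor routes torch functions through the colo_op_impl registry
via __torch_function__:

    import torch
    from colossalai.tensor import ColoTensor

    t = ColoTensor.init_from_torch_tensor(torch.randn(4, 4))
    out = torch.mean(t)                 # dispatched to colo_mean
    assert isinstance(out, ColoTensor)  # result is wrapped again after this patch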
colossalai/tensor/_ops/layernorm.py (new file)
@@ -0,0 +1,43 @@
+import torch
+
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor
+
+
+@colo_op_impl(torch.nn.functional.layer_norm)
+def colo_layernorm(types, args=(), kwargs=None, pg=None):
+    # Defaults mirror torch.nn.functional.layer_norm.
+    weight, bias, eps = None, None, 1e-5
+    arg_num = len(args)
+    if arg_num > 0:
+        input_tensor = args[0]
+    if arg_num > 1:
+        normalized_shape = args[1]
+    if arg_num > 2:
+        weight = args[2]
+    if arg_num > 3:
+        bias = args[3]
+    if arg_num > 4:
+        eps = args[4]
+
+    kwargs = kwargs or {}
+    if 'input' in kwargs:
+        input_tensor = kwargs['input']
+    if 'weight' in kwargs:
+        weight = kwargs['weight']
+    if 'bias' in kwargs:
+        bias = kwargs['bias']
+    if 'eps' in kwargs:
+        eps = kwargs['eps']
+
+    # Unwrap ColoTensor arguments to plain torch.Tensors.
+    if isinstance(input_tensor, ColoTensor):
+        input_tensor = input_tensor.torch_tensor()
+    if isinstance(weight, ColoTensor):
+        weight = weight.torch_tensor()
+    if isinstance(bias, ColoTensor):
+        bias = bias.torch_tensor()
+
+    # Run the native op and wrap the result back into a ColoTensor.
+    return ColoTensor.init_from_torch_tensor(
+        torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps))
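A hedged usage sketch (not part of the patch): calling the functional layer
norm on a ColoTensor should now be intercepted by colo_layernorm and return a
ColoTensor. Note that normalized_shape is only picked up positionally, since
the kwargs branch above does not read it:

    import torch
    import torch.nn.functional as F
    from colossalai.tensor import ColoTensor

    x = ColoTensor.init_from_torch_tensor(torch.randn(2, 8))
    y = F.layer_norm(x, (8,))           # dispatched to colo_layernorm
    assert isinstance(y, ColoTensor)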
colossalai/tensor/colo_tensor.py
@@ -8,6 +8,7 @@ from colossalai.context import ParallelMode
 from colossalai.nn.layer.utils import divide
 from colossalai.utils.cuda import get_current_device
 
+
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
@@ -145,3 +146,6 @@ class ColoTensor(object):
 
         kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
         return func(*args, **kwargs)
+
+    def backward(self, retain_graph: bool = False):
+        self._torch_tensor.backward(retain_graph=retain_graph)
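A minimal sketch of the new backward pass-through (not part of the patch),
assuming init_from_torch_tensor keeps a reference to the wrapped tensor rather
than copying it:

    import torch
    from colossalai.tensor import ColoTensor

    raw = torch.randn(3, requires_grad=True)
    w = ColoTensor.init_from_torch_tensor(raw)
    loss = torch.mean(w)   # colo_mean now returns a ColoTensor
    loss.backward()        # delegates to the underlying torch.Tensor.backward()
    print(raw.grad)        # gradients land on the original tensor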