ColossalAI/colossalai/legacy/nn/_ops/layernorm.py
Hongxin Liu 554aa9592e
[legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640)

* [legacy] refactor logger and clean up legacy codes (#4654)

* [legacy] make logger independent to gpc

* [legacy] make optim independent to registry

* [legacy] move test engine to legacy

* [legacy] move nn to legacy (#4656)

* [legacy] move nn to legacy

* [checkpointio] fix save hf config

* [test] remove useless rpc pp test

* [legacy] fix nn init

* [example] skip tutorial hybrid parallel example

* [devops] test doc check

* [devops] test doc check
2023-09-11 16:24:28 +08:00

29 lines
1.1 KiB
Python

from typing import List, Optional
import torch.nn.functional as F
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, convert_to_colo_tensor
@colo_op_impl(F.layer_norm)
def colo_layernorm(
    input_tensor: GeneralTensor,
    normalized_shape: List[int],
    weight: Optional[GeneralTensor] = None,
    bias: Optional[GeneralTensor] = None,
    eps: float = 1e-5,
):
    """Handler registered through ``colo_op_impl`` for ``F.layer_norm``.

    ``weight`` must be a ``ColoTensor``; its process group is the one passed
    to ``convert_to_colo_tensor`` for both ``input_tensor`` and ``bias``.
    The converted input is redistributed to ``ReplicaSpec()`` before the
    underlying ``F.layer_norm`` call.

    Args:
        input_tensor: Tensor to normalize.
        normalized_shape: Forwarded unchanged to ``F.layer_norm``.
        weight: Scale parameter; asserted to be a ``ColoTensor``.
        bias: Shift parameter, passed through ``convert_to_colo_tensor``.
        eps: Forwarded unchanged to ``F.layer_norm``.

    Returns:
        The ``F.layer_norm`` result wrapped by ``ColoTensor.from_torch_tensor``
        with a spec built from the redistributed input's process group and
        ``dist_spec``.

    Raises:
        AssertionError: If ``weight`` is not a ``ColoTensor`` (this includes
            the default ``None``).
    """
    assert isinstance(weight, ColoTensor)

    # The weight's process group is the reference for converting the other operands.
    weight_pg = weight.get_process_group()
    replicated_input = convert_to_colo_tensor(input_tensor, weight_pg)
    colo_bias = convert_to_colo_tensor(bias, weight_pg)

    # Redistribute the input to a replica spec before calling the torch kernel.
    replicated_input = replicated_input.redistribute(ReplicaSpec())

    normalized = F.layer_norm(
        replicated_input,
        normalized_shape,
        weight=weight,
        bias=colo_bias,
        eps=eps,
    )

    # Wrap the result using the redistributed input's group and dist spec.
    result_spec = ColoTensorSpec(
        pg=replicated_input.get_process_group(),
        dist_attr=replicated_input.dist_spec,
    )
    return ColoTensor.from_torch_tensor(tensor=normalized, spec=result_spec)