mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-24 02:30:56 +00:00
* [legacy] move communication to legacy (#4640) * [legacy] refactor logger and clean up legacy codes (#4654) * [legacy] make logger independent to gpc * [legacy] make optim independent to registry * [legacy] move test engine to legacy * [legacy] move nn to legacy (#4656) * [legacy] move nn to legacy * [checkpointio] fix save hf config * [test] remove useless rpc pp test * [legacy] fix nn init * [example] skip tutorial hybrid parallel example * [devops] test doc check * [devops] test doc check
29 lines
1.1 KiB
Python
29 lines
1.1 KiB
Python
from typing import List, Optional
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec
|
|
from colossalai.tensor.op_wrapper import colo_op_impl
|
|
|
|
from ._utils import GeneralTensor, convert_to_colo_tensor
|
|
|
|
|
|
@colo_op_impl(F.layer_norm)
def colo_layernorm(
    input_tensor: GeneralTensor,
    normalized_shape: List[int],
    weight: Optional[GeneralTensor] = None,
    bias: Optional[GeneralTensor] = None,
    eps: float = 1e-5,
):
    """ColoTensor implementation registered for ``torch.nn.functional.layer_norm``.

    The input is redistributed with ``ReplicaSpec()`` before normalisation.
    The result is wrapped in a ``ColoTensor`` whose spec takes its process
    group and dist spec from that redistributed input.

    Args:
        input_tensor: Tensor to normalise. It is passed through
            ``convert_to_colo_tensor`` with the weight's process group.
        normalized_shape: Forwarded unchanged to ``F.layer_norm``.
        weight: Elementwise affine scale. It must already be a ``ColoTensor``.
            Its process group is the reference used for the conversions below.
        bias: Optional elementwise affine shift. It is passed through
            ``convert_to_colo_tensor`` with the weight's process group.
        eps: Forwarded unchanged to ``F.layer_norm``.

    Returns:
        ColoTensor: The layer-normalised tensor.

    Raises:
        AssertionError: If ``weight`` is not a ``ColoTensor``. This includes
            the default ``None``. The check is skipped when Python runs with
            ``-O``.
    """
    assert isinstance(weight, ColoTensor)
    # The weight's process group is the reference for converting other operands.
    group = weight.get_process_group()

    colo_input = convert_to_colo_tensor(input_tensor, group)
    colo_bias = convert_to_colo_tensor(bias, group)

    # Move the input to a replicated layout before calling the dense kernel.
    replicated = colo_input.redistribute(ReplicaSpec())

    normed = F.layer_norm(replicated, normalized_shape, weight=weight, bias=colo_bias, eps=eps)

    # Output metadata mirrors the redistributed input, exactly as before.
    out_spec = ColoTensorSpec(pg=replicated.get_process_group(), dist_attr=replicated.dist_spec)
    return ColoTensor.from_torch_tensor(tensor=normed, spec=out_spec)
|