[shardformer] refactored layernorm (#4086)

commit d33a44e8c3
parent c4b1b65931
Author: Frank Lee
Date: 2023-06-26 18:05:00 +08:00

4 changed files with 51 additions and 77 deletions


@@ -1,11 +1,11 @@
 from .dropout import Dropout1D
 from .embedding import Embedding1D, VocabParallelEmbedding1D
-from .layernorm import LayerNorm1D
+from .layernorm import FusedLayerNorm
 from .linear import Linear1D_Col, Linear1D_Row
 from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
 from .loss import cross_entropy_1d
 
 __all__ = [
     "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
-    "Dropout1D", "cross_entropy_1d", 'LayerNorm1D'
+    "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm'
 ]
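
The diff only renames the exported symbol from LayerNorm1D to FusedLayerNorm, signalling that the layer is now a drop-in fused LayerNorm rather than a 1D tensor-parallel variant. The sketch below illustrates how the new export would typically be used. It is a minimal sketch: the colossalai.shardformer.layer import path and the from_native_module constructor are assumptions based on shardformer's usual module-wrapping pattern, not confirmed by this commit.

import torch
import torch.nn as nn

# Assumed import path; the diff does not show the file's location.
from colossalai.shardformer.layer import FusedLayerNorm

# A plain PyTorch LayerNorm, e.g. over a 768-dim hidden state.
native_norm = nn.LayerNorm(768)

# Assumed API: wrap the native module so the fused implementation
# reuses its existing weight and bias.
fused_norm = FusedLayerNorm.from_native_module(native_norm)

x = torch.randn(4, 16, 768)
out = fused_norm(x)  # numerically equivalent to native_norm(x)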