Remove CohereLayerNorm and use existing layernorm

GuangyaoZhang
2024-06-14 09:14:01 +00:00
parent fe2e74c03a
commit 7a2b08646f
3 changed files with 22 additions and 136 deletions

@@ -7,8 +7,8 @@ from torch import Tensor
 from torch.nn import Module
 
 from colossalai.shardformer.layer import (
-    CohereLayerNorm,
-    FusedCohereLayerNorm,
+    FusedLayerNorm,
+    LayerNorm,
     Linear1D_Col,
     Linear1D_Row,
     PaddingEmbedding,
@@ -64,9 +64,9 @@ class CommandPolicy(Policy):
             embedding_cls = PaddingEmbedding
 
         if self.shard_config.enable_fused_normalization:
-            norm_cls = FusedCohereLayerNorm
+            norm_cls = FusedLayerNorm
         else:
-            norm_cls = CohereLayerNorm
+            norm_cls = LayerNorm
 
         if self.pipeline_stage_manager is not None:
             self.shard_config.enable_sequence_parallelism = False
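
For context, the second hunk reduces the policy's norm-class choice to a plain two-way switch on the fused-normalization flag. A minimal sketch of the post-commit behavior, assuming ColossalAI is installed; pick_norm_cls is a hypothetical helper name, and shard_config only needs the enable_fused_normalization attribute visible in the diff:

    from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm

    def pick_norm_cls(shard_config):
        # Mirrors the post-commit logic in CommandPolicy: the generic
        # FusedLayerNorm / LayerNorm wrappers stand in for the removed
        # CohereLayerNorm / FusedCohereLayerNorm classes.
        if shard_config.enable_fused_normalization:
            return FusedLayerNorm
        return LayerNorm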