[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

Authored by Frank Lee on 2024-01-25 17:01:48 +08:00, committed by GitHub
commit 7cfed5f076 (parent d7f8db8e21)
157 changed files with 1353 additions and 8966 deletions


@@ -11,7 +11,6 @@ from torch import Tensor
 from torch.nn.parameter import Parameter
 from colossalai.accelerator import get_accelerator
-from colossalai.kernel import LayerNorm
 from colossalai.legacy.communication import broadcast
 from colossalai.legacy.context import ParallelMode, seed
 from colossalai.legacy.context.parallel_context import global_context as gpc
@@ -23,6 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
+from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
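
The net effect of this hunk is an import swap: LayerNorm is no longer pulled from colossalai.kernel but is aliased from MixedFusedLayerNorm in colossalai.nn.layer.layernorm. A minimal usage sketch follows; it assumes MixedFusedLayerNorm keeps a torch.nn.LayerNorm-style (normalized_shape, eps) constructor, and the hidden size and CUDA placement are purely illustrative.

import torch
# After this refactor, LayerNorm is obtained via the alias below instead of colossalai.kernel.
from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm

# Hypothetical hidden size for illustration; the fused kernel typically expects CUDA tensors.
norm = LayerNorm(1024, eps=1e-5).cuda()
x = torch.randn(2, 16, 1024, device="cuda")
y = norm(x)            # same forward call as torch.nn.LayerNorm
print(y.shape)         # torch.Size([2, 16, 1024])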