Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-26 12:14:02 +00:00)
[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)
* Add layer norm gradients all-reduce for sequence parallel.
* skip pipeline inference test
* [hotfix] fixing policies of sequence parallel (#4922)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy
---------
Co-authored-by: littsk <1214689160@qq.com>
* Hotfix/add grad all reduce for sequence parallel (#4927)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy
* fix bug using wrong variables
---------
Co-authored-by: littsk <1214689160@qq.com>
* fix policy initialization
* fix bloom and chatglm policies
* polish code of handling layernorm
* fix moe module
* polish code of class initializing
---------
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
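Background on the fix: under sequence parallelism the activations are split along the sequence dimension, so each rank computes only a partial gradient for the (unsharded) layer norm weight and bias; those partial gradients must be summed across the tensor-parallel group before the optimizer step, or the norms silently drift apart across ranks. A minimal sketch of the idea in plain PyTorch (illustrative only, not ColossalAI's implementation; the function name and the process-group handle are assumptions):

import torch
import torch.distributed as dist

def allreduce_layernorm_grads(model: torch.nn.Module, tp_group: dist.ProcessGroup) -> None:
    # Each rank only saw its own slice of the sequence, so LayerNorm
    # weight/bias grads are partial sums; summing across the
    # tensor-parallel group restores the full gradient on every rank.
    for module in model.modules():
        if isinstance(module, torch.nn.LayerNorm):
            for param in module.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)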
@@ -11,6 +11,7 @@ from torch.nn import Module
 from colossalai.pipeline.stage_manager import PipelineStageManager

+from ..layer.normalization import BaseLayerNorm
 from ..layer.parallel_module import ParallelModule
 from ..shard.shard_config import ShardConfig
@@ -29,7 +30,7 @@ class SubModuleReplacementDescription:
         ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
     """

     suffix: str
-    target_module: ParallelModule
+    target_module: Union[ParallelModule, BaseLayerNorm]
     kwargs: Dict[str, Any] = None
     ignore_if_not_exist: bool = False
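As an illustration of the widened field, a policy entry can now target a layer-norm wrapper directly alongside parallel modules (a hedged sketch: the suffix strings are hypothetical, and FusedLayerNorm from colossalai.shardformer.layer is assumed here to implement BaseLayerNorm):

from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col

descriptions = [
    SubModuleReplacementDescription(
        suffix="self_attention.dense",  # assumed submodule path
        target_module=Linear1D_Col,     # a ParallelModule, as before
    ),
    SubModuleReplacementDescription(
        suffix="input_layernorm",       # assumed submodule path
        target_module=FusedLayerNorm,   # a BaseLayerNorm, accepted after this change
    ),
]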
@@ -77,7 +78,6 @@ class Policy(ABC):
     def set_model(self, model: nn.Module) -> None:
         r"""
         Set model as an attribute of the Policy object so that we can access the model's attributes.

         Args:
             model (:class:`nn.Module`): The model to be perform
         """
@@ -86,11 +86,11 @@ class Policy(ABC):
     def set_shard_config(self, shard_config: ShardConfig) -> None:
         r"""
         Set shard config as an attribute of the Policy object.

         Args:
             shard_config (:class:`ShardConfig`): The shard config to be perform
         """
         self.shard_config = shard_config

         self.config_sanity_check()

     @property
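For reference, the call order these setters imply, as a hedged sketch (the driver lines are illustrative, not shardformer internals; get_autopolicy's exact parameters were changed by this very PR, so they are assumed here):

from colossalai.shardformer.policies.auto_policy import get_autopolicy

policy = get_autopolicy(model)         # parameters assumed; this PR fixed how they are passed
policy.set_model(model)                # gives the policy access to the model's attributes
policy.set_shard_config(shard_config)  # stores the config, then runs config_sanity_check()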