[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)

* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing polices of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------

Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.


* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------

Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policices

* polish code of handling layernorm

* fix moe module

* polish code of class initializing

---------

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
This commit is contained in:
littsk
2023-11-03 13:32:43 +08:00
committed by GitHub
parent d99b2c961a
commit 1a3315e336
30 changed files with 1120 additions and 552 deletions

View File

@@ -11,6 +11,7 @@ from torch.nn import Module
from colossalai.pipeline.stage_manager import PipelineStageManager
from ..layer.normalization import BaseLayerNorm
from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig
@@ -29,7 +30,7 @@ class SubModuleReplacementDescription:
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
target_module: ParallelModule
target_module: Union[ParallelModule, BaseLayerNorm]
kwargs: Dict[str, Any] = None
ignore_if_not_exist: bool = False
@@ -77,7 +78,6 @@ class Policy(ABC):
def set_model(self, model: nn.Module) -> None:
r"""
Set model as an attribute of the Policy object so that we can access the model's attributes.
Args:
model (:class:`nn.Module`): The model to be perform
"""
@@ -86,11 +86,11 @@ class Policy(ABC):
def set_shard_config(self, shard_config: ShardConfig) -> None:
r"""
Set shard config as an attribute of the Policy object.
Args:
shard_config (:class:`ShardConfig`): The shard config to be perform
"""
self.shard_config = shard_config
self.config_sanity_check()
@property