Mirror of https://github.com/hpcaitech/ColossalAI.git
[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)
  * Add layer norm gradients all-reduce for sequence parallel.
  * Skip pipeline inference test.
* [hotfix] Fix policies of sequence parallel (#4922)
  * Add layer norm gradients all-reduce for sequence parallel.
  * Fix parameter passing when calling get_autopolicy.

  Co-authored-by: littsk <1214689160@qq.com>
* Hotfix/add grad all reduce for sequence parallel (#4927)
  * Add layer norm gradients all-reduce for sequence parallel.
  * Fix parameter passing when calling get_autopolicy.
  * Fix bug using wrong variables.

  Co-authored-by: littsk <1214689160@qq.com>
* Fix policy initialization.
* Fix bloom and chatglm policies.
* Polish code of handling layernorm.
* Fix moe module.
* Polish code of class initializing.

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
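The core of the fix: under sequence parallelism each rank only sees its own shard of the sequence, so the locally computed gradients of layer norm weights (and biases) are partial sums and must be all-reduced across the sequence-parallel group before the optimizer step. A minimal sketch of that idea, assuming sp_group is the sequence-parallel process group and "partial_derived" is a hypothetical flag used to identify the marked parameters (stand-in names, not the actual ColossalAI API):

import torch.nn as nn
import torch.distributed as dist

def allreduce_sp_partial_grads(module: nn.Module, sp_group: dist.ProcessGroup) -> None:
    # Sum the partial gradients of every parameter tagged as "partially derived"
    # under sequence parallelism (e.g. layer norm weights/biases), so that all
    # ranks hold the full gradient before the optimizer step.
    for param in module.parameters():
        if getattr(param, "partial_derived", False) and param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=sp_group)

In the diff below, SeqParallelUtils.marked_as_sp_partial_derived_param plays the tagging role; the corresponding reduction is then expected to happen on the framework side after backward.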
@@ -1,11 +1,14 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+from abc import ABC, abstractmethod
 
 import torch.nn as nn
 
 from colossalai.lazy import LazyInitContext
 
-__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
+from .utils import SeqParallelUtils
+
+__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
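The newly imported SeqParallelUtils.marked_as_sp_partial_derived_param is what the classes below call to tag norm parameters. As a rough illustration of what such a marker could amount to, here is a hypothetical stand-in that simply attaches the flag checked by the sketch above (not the actual ColossalAI implementation):

import torch.nn as nn

class SeqParallelUtilsSketch:
    # Hypothetical stand-in for the imported helper: tag a parameter so that a
    # post-backward pass knows its gradient must be all-reduced over the
    # sequence-parallel group.
    @staticmethod
    def marked_as_sp_partial_derived_param(param: nn.Parameter) -> None:
        setattr(param, "partial_derived", True)

    @staticmethod
    def is_sp_partial_derived_param(param: nn.Parameter) -> bool:
        return getattr(param, "partial_derived", False)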
@@ -35,7 +38,103 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
 ]
 
 
-class FusedLayerNorm:
+class BaseLayerNorm(ABC):
+    @abstractmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
+        """
+        Convert a native PyTorch layer normalization module to a specific layer normalization module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native PyTorch layer normalization module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The specific layer normalization module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of the supported layer normalization type.
+        """
+
+
+class RMSNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "RMSNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to a colossalai layer norm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        """
+        Convert a native RMSNorm module to a colossalai layer norm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The RMSNorm module.
+        """
+
+        LazyInitContext.materialize(module)
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+
+        return module
+
+
+class LayerNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "LayerNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to a colossalai layer norm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native pytorch layer norm module to a colossalai layer norm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The LayerNorm module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
+        """
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+
+        LazyInitContext.materialize(module)
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+
+        return module
+
+
+class FusedLayerNorm(BaseLayerNorm):
     r"""
     This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
     """
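For context, this is roughly how the wrappers introduced above are meant to be used by a shardformer policy: convert the native norm in place and pass sp_partial_derived=True whenever sequence parallelism shards the activations feeding it. A small illustrative snippet (the import path and the surrounding module are assumptions for the example):

import torch.nn as nn

from colossalai.shardformer.layer import LayerNorm  # assumed import path for the wrapper above

# A toy block holding a native layer norm.
block = nn.TransformerEncoderLayer(d_model=1024, nhead=16)

# Replace the native norm in place; with sp_partial_derived=True its weight and
# bias are tagged so their gradients are all-reduced across the sequence-parallel
# group during backward.
block.norm1 = LayerNorm.from_native_module(block.norm1, sp_partial_derived=True)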
@@ -43,15 +142,29 @@ class FusedLayerNorm:
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to the FusedLayerNorm module provided by apex."
         )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module
+        Convert a native pytorch layer norm module to the FusedLayerNorm module provided by apex,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
         # check if apex is installed
+
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+
         try:
             pass
         except ImportError:
@@ -85,10 +198,18 @@ class FusedLayerNorm:
 
         layernorm.weight = module.weight
         layernorm.bias = module.bias
 
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+
         return layernorm
 
 
-class FusedRMSNorm:
+class FusedRMSNorm(BaseLayerNorm):
     """
     This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
     """
@@ -96,11 +217,22 @@ class FusedRMSNorm:
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedRMSNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to the FusedRMSNorm module provided by apex."
         )
 
     @staticmethod
-    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native RMSNorm module to the FusedRMSNorm module provided by apex,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: FusedRMSNorm module.
+        """
         try:
             from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
         except ImportError:
@@ -124,4 +256,10 @@ class FusedRMSNorm:
 
         rmsnorm.weight = module.weight
 
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
+
         return rmsnorm
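Putting the pieces together, the flow this hotfix enables looks roughly like the following on a single rank (toy illustration; in a real run each rank would see only its shard of the sequence, and the gradient reduction would use the actual sequence-parallel utilities rather than the sketches above):

import torch
import torch.nn as nn

from colossalai.shardformer.layer import LayerNorm  # assumed import path

norm = LayerNorm.from_native_module(nn.LayerNorm(1024), sp_partial_derived=True)

local_chunk = torch.randn(2, 128, 1024)  # this rank's shard of the sequence
loss = norm(local_chunk).sum()
loss.backward()  # weight/bias grads are partial sums over the local shard

# allreduce_sp_partial_grads(norm, sp_group)  # sum the partial grads (see the sketch above)
# optimizer.step()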