diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 2826a8429..7fad4948d 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,12 +1,12 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .layernorm import FusedLayerNorm from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d +from .normalization import FusedLayerNorm, FusedRMSNorm from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm' + 'FusedLayerNorm', 'FusedRMSNorm' ] diff --git a/colossalai/shardformer/layer/layernorm.py b/colossalai/shardformer/layer/normalization.py similarity index 59% rename from colossalai/shardformer/layer/layernorm.py rename to colossalai/shardformer/layer/normalization.py index 6103380fe..b27307154 100644 --- a/colossalai/shardformer/layer/layernorm.py +++ b/colossalai/shardformer/layer/normalization.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -__all__ = ['FusedLayerNorm'] +__all__ = ['FusedLayerNorm', 'FusedRMSNorm'] FAST_LAYERNORM_SUPPORTED_SIZE = [ 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576, @@ -61,4 +61,44 @@ class FusedLayerNorm(): # copy weight and bias layernorm.weight.copy_(module.weight) layernorm.bias.copy_(module.bias) - return layernorm \ No newline at end of file + return layernorm + + +class FusedRMSNorm(): + """ + This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + 'FusedRMSNorm is not implemented as a physical class. ' + 'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.' 
+ ) + + @staticmethod + def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: + try: + from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm + except ImportError: + raise ImportError( + 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel' + ) + + # to check if it is huggingface LlamaRMSNorm + if module.__class__.__name__ == "LlamaRMSNorm": + normalized_shape = module.weight.shape[0] + eps = module.variance_epsilon + elementwise_affine = True + else: + # get the attributes of the module + normalized_shape = module.normalized_shape + eps = module.eps + elementwise_affine = module.elementwise_affine + + rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + + with torch.no_grad(): + # copy weight and bias + rmsnorm.weight.copy_(module.weight) + + return rmsnorm diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 7e9bcf209..8835e38cb 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -98,6 +98,14 @@ class Policy(ABC): shard_config (:class:`ShardConfig`): The shard config to be perform """ self.shard_config = shard_config + self.config_sanity_check() + + @abstractmethod + def config_sanity_check(self): + """ + Check if the shard config is valid for the model. Raise an exception if the config is invalid. + """ + pass @abstractmethod def preprocess(self) -> nn.Module: diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 49ef53259..545669f1f 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -16,6 +16,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes class BertPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # reshape the embedding layer r""" @@ -99,7 +102,8 @@ class BertPolicy(Policy): ]) } - if self.shard_config.fused_layernorm: + # optimization configuration + if self.shard_config.enable_fused_normalization: base_policy[BertLayer].sub_module_replacement.append( SubModuleReplacementDescription( suffix="attention.output.LayerNorm", @@ -150,12 +154,16 @@ class BertForPretrainingPolicy(BertPolicy): kwargs={"gather_output": True}), ]) } - if self.shard_config.fused_layernorm: + + # optimization configuration + if self.shard_config.enable_fused_normalization: addon_module[BertLMPredictionHead].sub_module_replacement.append( SubModuleReplacementDescription( suffix="transform.LayerNorm", target_module=col_nn.FusedLayerNorm, )) + + # append extra policy module_policy.update(addon_module) return module_policy @@ -187,7 +195,7 @@ class BertLMHeadModelPolicy(BertPolicy): kwargs={"gather_output": True}), ]) } - if self.shard_config.fused_layernorm: + if self.shard_config.enable_fused_normalization: addon_module[BertLMPredictionHead].sub_module_replacement.append( SubModuleReplacementDescription( suffix="transform.LayerNorm", @@ -224,12 +232,15 @@ class BertForMaskedLMPolicy(BertPolicy): kwargs={"gather_output": True}), ]) } - if self.shard_config.fused_layernorm: + + # optimization configuration + if self.shard_config.enable_fused_normalization: addon_module[BertLMPredictionHead].sub_module_replacement.append( SubModuleReplacementDescription( suffix="transform.LayerNorm", target_module=col_nn.FusedLayerNorm, )) + module_policy.update(addon_module) return module_policy @@ -316,4 
+327,4 @@ class BertForMultipleChoicePolicy(BertPolicy): ]) } module_policy.update(addon_module) - return module_policy \ No newline at end of file + return module_policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index d196bdbd6..4e34f2464 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -65,6 +65,9 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, class BloomPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # reshape the embedding layer r""" @@ -81,7 +84,7 @@ class BloomPolicy(Policy): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel - return { + base_policy = { BloomBlock: ModulePolicyDescription( attribute_replacement={ @@ -99,7 +102,6 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - # kwargs={'n_fused': 3} ), SubModuleReplacementDescription( suffix="self_attention.dense", @@ -132,6 +134,31 @@ class BloomPolicy(Policy): ]) } + # optimization configuration + if self.shard_config.enable_fused_normalization: + base_policy[BloomModel].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="word_embeddings_layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ]) + base_policy[BloomBlock].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ]) + + return base_policy + def new_model_class(self): # do nothing return self.model diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index ebfaf8a8e..3d6d94b8e 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -9,6 +9,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes class GPT2Policy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # reshape the embedding layer r""" @@ -22,7 +25,7 @@ class GPT2Policy(Policy): return self.model def module_policy(self): - return { + base_policy = { GPT2Model: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -77,6 +80,30 @@ class GPT2Policy(Policy): ]) } + # optimization configuration + if self.shard_config.enable_fused_normalization: + base_policy[GPT2Model].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + )) + + base_policy[GPT2Block].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="ln_1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="ln_2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription(suffix="ln_cross_attn", + target_module=col_nn.FusedLayerNorm, + ignore_if_not_exist=True) + ]) + + return base_policy + def new_model_class(self): return self.model diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index a13f5f087..b36180ce3 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -4,13 +4,16 @@ import torch.nn as nn from transformers import 
LlamaForCausalLM, LlamaForSequenceClassification from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class LlamaPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # Resize embedding vocab_size = self.model.config.vocab_size @@ -23,7 +26,7 @@ class LlamaPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - return { + base_policy = { LlamaDecoderLayer: ModulePolicyDescription( attribute_replacement={ @@ -75,6 +78,27 @@ class LlamaPolicy(Policy): ]) } + # optimization configuration + if self.shard_config.enable_fused_normalization: + base_policy[LlamaDecoderLayer].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ) + ]) + + base_policy[LlamaModel].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + )) + + return base_policy + def new_model_class(self): return None diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index f467726e5..ce3873954 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -13,6 +13,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes class OPTPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # reshape the embedding layer r""" @@ -74,7 +77,9 @@ class OPTPolicy(Policy): ), ]), } - if self.shard_config.fused_layernorm: + + # optimization configuration + if self.shard_config.enable_fused_normalization: base_policy[OPTDecoder].sub_module_replacement.append( SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, @@ -87,6 +92,7 @@ class OPTPolicy(Policy): target_module=FusedLayerNorm, ignore_if_not_exist=True) ]) + return base_policy def new_model_class(self): diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 8d8abc9f7..d35f688a0 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -9,7 +9,7 @@ from transformers.models.t5.modeling_t5 import ( T5Stack, ) -from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -18,6 +18,9 @@ __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy class T5ModelPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # reshape the embedding layer r""" @@ -31,7 +34,7 @@ class T5ModelPolicy(Policy): return self.model def module_policy(self): - return { + base_policy = { T5Stack: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -139,6 +142,19 @@ class T5ModelPolicy(Policy): ]) } + # optimization configuration + if self.shard_config.enable_fused_normalization: + 
base_policy[T5LayerFF].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) + base_policy[T5LayerSelfAttention].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) + base_policy[T5LayerCrossAttention].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) + base_policy[T5Stack].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm)) + + return base_policy + def new_model_class(self): return None @@ -167,4 +183,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy): class T5EncoderPolicy(T5ModelPolicy): - pass \ No newline at end of file + pass diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 550f8f997..5d8a235db 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -3,13 +3,16 @@ from typing import Dict, Union import torch.nn as nn from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel -from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class ViTPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # Resize embedding vocab_size = self.model.config.vocab_size @@ -22,7 +25,7 @@ class ViTPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - return { + base_policy = { ViTEmbeddings: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -80,6 +83,26 @@ class ViTPolicy(Policy): ]), } + # optimization configuration + if self.shard_config.enable_fused_normalization: + base_policy[ViTAttention].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="layernorm_before", + target_module=FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layernorm_after", + target_module=FusedLayerNorm, + ) + ]) + base_policy[ViTModel].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="layernorm", + target_module=FusedLayerNorm, + )) + + return base_policy + def new_model_class(self): return None diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 8d3fc225e..428ebc978 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -12,16 +12,10 @@ class ShardConfig: Args: tensor_parallel_size (int): The size of tensor parallel - use_mixedfusedLN (bool): Whether to use the `MixedFusedLayerNorm` - data_parallel_size (int): The size of data parallel - pipeline_parallel_size (int): The size of pipeline parallel - tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d'] - inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model - will not calculate the loss and just return the output. 
-        gather_output (bool): Whether to gather the output of the model of the last layer
+        enable_fused_normalization (bool): Whether to replace layer norm and RMS norm with the fused kernels from apex. Defaults to False.
     """
     tensor_parallel_size: int
-    fused_layernorm: bool = False
+    enable_fused_normalization: bool = False

     # TODO: add support for tensor parallel
     # pipeline_parallel_size: int
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index ad7c408ae..e49b0246c 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -8,11 +8,11 @@ def build_model(world_size, model_fn):
     org_model = model_fn().cuda()

     # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
+    shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(model_copy)
+    sharded_model = shard_former.shard_model(model_copy).cuda()
     return org_model, sharded_model


@@ -33,4 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
     shard_output = sharded_model(**data)
     shard_output = output_transform_fn(shard_output)
     shard_loss = loss_fn(shard_output)
-    return org_output, org_loss, shard_output, shard_loss
\ No newline at end of file
+    return org_output, org_loss, shard_output, shard_loss
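
Usage note for trying the patch locally: the sketch below mirrors the updated build_model helper in tests/test_shardformer/test_model/_utils.py. The top-level import path for ShardConfig/ShardFormer and the small BERT stand-in model are assumptions for illustration, not part of this diff.

    import copy

    from transformers import BertConfig, BertForMaskedLM

    # assumed public re-export; the test helper above uses the same two classes
    from colossalai.shardformer import ShardConfig, ShardFormer


    def build_sharded_model(world_size: int):
        # any supported architecture works; a tiny BERT keeps the example cheap
        org_model = BertForMaskedLM(BertConfig(num_hidden_layers=2)).cuda()

        # enable_fused_normalization replaces the old fused_layernorm flag and swaps
        # LayerNorm/RMSNorm modules for the apex fused kernels via the policies above
        shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)

        shard_former = ShardFormer(shard_config=shard_config)
        shard_former.init_distributed()
        sharded_model = shard_former.shard_model(copy.deepcopy(org_model)).cuda()
        return org_model, sharded_model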
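
The from_native_module interface added in normalization.py can also be exercised directly. A minimal sketch, assuming apex is built from source and that LlamaRMSNorm keeps its (hidden_size, eps) constructor from transformers; the hidden size and tolerances are arbitrary:

    import torch
    from transformers.models.llama.modeling_llama import LlamaRMSNorm

    from colossalai.shardformer.layer import FusedRMSNorm

    native_norm = LlamaRMSNorm(4096, eps=1e-6).cuda()
    # returns an apex FusedRMSNorm with the weight copied from the native module
    fused_norm = FusedRMSNorm.from_native_module(native_norm)

    x = torch.randn(2, 16, 4096, device="cuda")
    # the two implementations should agree up to kernel-level numerical noise
    torch.testing.assert_close(native_norm(x), fused_norm(x), rtol=1e-3, atol=1e-3)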