[shardformer] supported fused normalization (#4112)

Frank Lee
2023-06-30 09:32:37 +08:00
parent b1c2901530
commit f3b6aaa6b7
12 changed files with 207 additions and 31 deletions
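The core of the change is that LlamaPolicy now reads shard_config.enable_fused_normalization and, when the flag is set, swaps the Llama RMSNorm layers for FusedRMSNorm from colossalai.shardformer.layer. A minimal usage sketch follows; the flag name comes from this diff, but the ShardConfig/ShardFormer entry point and the checkpoint name are assumptions for illustration, not part of this commit:

    from transformers import LlamaForCausalLM

    from colossalai.shardformer import ShardConfig, ShardFormer  # assumed import path

    # Illustrative checkpoint; any Llama checkpoint is handled the same way.
    model = LlamaForCausalLM.from_pretrained("huggyllama/llama-7b")

    # enable_fused_normalization is the flag the policy below checks; with it on,
    # input_layernorm, post_attention_layernorm and the final norm are replaced
    # by FusedRMSNorm instead of the native LlamaRMSNorm.
    shard_config = ShardConfig(enable_fused_normalization=True)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model = shard_former.optimize(model)  # assumed entry point name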


@@ -4,13 +4,16 @@ import torch.nn as nn
 from transformers import LlamaForCausalLM, LlamaForSequenceClassification
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


 class LlamaPolicy(Policy):

     def config_sanity_check(self):
         pass

     def preprocess(self):
         # Resize embedding
         vocab_size = self.model.config.vocab_size
@@ -23,7 +26,7 @@ class LlamaPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        return {
+        base_policy = {
             LlamaDecoderLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
@@ -75,6 +78,27 @@ class LlamaPolicy(Policy):
                 ])
         }

+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
+            base_policy[LlamaDecoderLayer].sub_module_replacement.extend([
+                SubModuleReplacementDescription(
+                    suffix="input_layernorm",
+                    target_module=FusedRMSNorm,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="post_attention_layernorm",
+                    target_module=FusedRMSNorm,
+                )
+            ])
+
+            base_policy[LlamaModel].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="norm",
+                    target_module=FusedRMSNorm,
+                ))
+
+        return base_policy

     def new_model_class(self):
         return None
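For intuition about what the added descriptions do: each SubModuleReplacementDescription names a child module by attribute suffix and the module class that should take its place. The sketch below is a simplified, hypothetical applier (apply_norm_replacements is not part of ColossalAI); it only handles the flat suffixes used above and stands in for the real shardformer sharding machinery, which also converts the native LlamaRMSNorm into the fused module:

    import torch.nn as nn


    def apply_norm_replacements(module: nn.Module, descriptions, convert):
        """Hypothetical sketch: swap each named child for its fused counterpart.

        `convert` stands in for however the real sharder builds the target module
        (e.g. from the native layer's weight and epsilon); only flat attribute
        names such as "input_layernorm" are handled here.
        """
        for desc in descriptions:
            native = getattr(module, desc.suffix)          # e.g. LlamaRMSNorm
            setattr(module, desc.suffix, convert(desc.target_module, native))

Applied once per LlamaDecoderLayer, this is what the extend([...]) above amounts to; the append on LlamaModel handles the model's final norm the same way.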