Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-28 13:05:26 +00:00
[shardformer] supported fused normalization (#4112)
@@ -4,13 +4,16 @@ import torch.nn as nn
 from transformers import LlamaForCausalLM, LlamaForSequenceClassification
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
 class LlamaPolicy(Policy):
 
     def config_sanity_check(self):
         pass
 
     def preprocess(self):
         # Resize embedding
         vocab_size = self.model.config.vocab_size
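The only change in this hunk is the import: the policy now pulls in FusedRMSNorm alongside the tensor-parallel layers so that LLaMA's RMSNorm modules can be swapped for a fused kernel further down. For orientation, the computation such a kernel fuses is ordinary RMSNorm; the sketch below is a plain-PyTorch reference (mirroring LlamaRMSNorm), not ColossalAI's FusedRMSNorm implementation.

import torch
import torch.nn as nn


class ReferenceRMSNorm(nn.Module):
    """Plain RMSNorm for reference: normalize by the root mean square over the
    last dimension, then apply a learned scale (no bias, unlike LayerNorm)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)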
@@ -23,7 +26,7 @@ class LlamaPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        return {
+        base_policy = {
             LlamaDecoderLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
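The policy is declarative: it maps a module class (here LlamaDecoderLayer) to a ModulePolicyDescription whose sub_module_replacement entries name a sub-module by suffix plus the class to substitute for it, and a separate sharding engine consumes that mapping. The toy below illustrates only that consumption pattern; the class names are hypothetical stand-ins, and this is not ColossalAI's actual sharder, which also handles weight conversion and tensor-parallel process groups.

from dataclasses import dataclass, field
from typing import Dict, List, Type

import torch.nn as nn


@dataclass
class SubModuleReplacement:          # hypothetical stand-in for SubModuleReplacementDescription
    suffix: str                      # attribute path of the sub-module inside the matched module
    target_module: Type[nn.Module]   # module class to substitute in


@dataclass
class ModulePolicy:                  # hypothetical stand-in for ModulePolicyDescription
    sub_module_replacement: List[SubModuleReplacement] = field(default_factory=list)


def apply_policy(model: nn.Module, policy: Dict[Type[nn.Module], ModulePolicy]) -> nn.Module:
    """Walk the model; for every module whose type has a policy entry, swap the
    sub-module named by each description's suffix for the target class."""
    planned = []
    for module in model.modules():
        description = policy.get(type(module))
        if description is None:
            continue
        for repl in description.sub_module_replacement:
            parent, _, attr = repl.suffix.rpartition(".")
            holder = module.get_submodule(parent) if parent else module
            planned.append((holder, attr, repl.target_module))
    for holder, attr, target in planned:
        native = getattr(holder, attr)
        # A real implementation would carry the weights over (e.g. via a
        # from_native_module-style constructor); this toy just rebuilds by size.
        setattr(holder, attr, target(native.weight.shape[-1]))
    return model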
@@ -75,6 +78,27 @@ class LlamaPolicy(Policy):
                 ])
         }
 
+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
+            base_policy[LlamaDecoderLayer].sub_module_replacement.extend([
+                SubModuleReplacementDescription(
+                    suffix="input_layernorm",
+                    target_module=FusedRMSNorm,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="post_attention_layernorm",
+                    target_module=FusedRMSNorm,
+                )
+            ])
+
+            base_policy[LlamaModel].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="norm",
+                    target_module=FusedRMSNorm,
+                ))
+
+        return base_policy
 
     def new_model_class(self):
         return None
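Taken together, the commit means that once a ShardConfig is created with enable_fused_normalization=True, LlamaPolicy routes the decoder layers' input_layernorm / post_attention_layernorm and the model-level norm through FusedRMSNorm instead of the eager LlamaRMSNorm. A minimal usage sketch follows, assuming the ShardFormer/ShardConfig API of this era of ColossalAI; exact constructor and return signatures may differ in your version, and a distributed backend is assumed to be initialized (e.g. via colossalai.launch_from_torch) before sharding.

from transformers import LlamaConfig, LlamaForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

# Tiny config purely for illustration.
model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2, hidden_size=256,
                                     intermediate_size=512, num_attention_heads=4))

# The flag added in this commit: the policy checks shard_config.enable_fused_normalization.
shard_config = ShardConfig(enable_fused_normalization=True)

shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model)  # return value assumed; some versions also return shared parameters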