[shardformer] support llama model using shardformer (#3969)
adjust layer attr
colossalai/shardformer/policies/llama.py (new file, 122 additions)
@@ -0,0 +1,122 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type

import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

import colossalai.shardformer.layer.layers as col_nn

from .basepolicy import Argument, Col_Layer, Policy, Row_Layer

class LlamaPolicy(Policy):

    @staticmethod
    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
        return {
            LlamaDecoderLayer:
                Argument(attr_dict={
                    "self_attn.hidden_size": config.hidden_size // world_size,
                    "self_attn.num_heads": config.num_attention_heads // world_size,
                },
                         param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]),
            LlamaModel:
                Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings])
        }
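    # The attr_dict above rewrites the decoder layer's attributes for the tensor-parallel
    # case: with, say, hidden_size=4096, num_attention_heads=32 and world_size=4
    # (illustrative numbers), every rank keeps self_attn.hidden_size=1024 and
    # self_attn.num_heads=8. Both values are assumed to divide evenly by world_size.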

    @staticmethod
    def attn_layer() -> List:
        # q/k/v projections are column-parallel; the output projection is row-parallel,
        # which reduces the per-rank partial results back to the full hidden size.
        return [
            Col_Layer(
                suffix="self_attn.q_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                suffix="self_attn.k_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                suffix="self_attn.v_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Row_Layer(
                suffix="self_attn.o_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Row,
            )
        ]

    @staticmethod
    def mlp_layer() -> List:
        # Each MLP projection is column-parallel and gathers its output, so every rank
        # sees the full activation between projections.
        return [
            Col_Layer(
                suffix="mlp.gate_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
                gather_output=True,
            ),
            Col_Layer(
                suffix="mlp.up_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
                gather_output=True,
            ),
            Col_Layer(
                suffix="mlp.down_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
                gather_output=True,
            ),
        ]

    @staticmethod
    def embeddings() -> List:
        return [Col_Layer(
            suffix="embed_tokens",
            weight="weight",
            replace_layer=col_nn.VocabParallelEmbedding1D,
        )]

from transformers import LlamaForCausalLM


class LlamaForCausalLMPolicy(LlamaPolicy):

    @staticmethod
    def argument(config, world_size):
        llamapolicy = LlamaPolicy.argument_policy(config, world_size)
        argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])}
        argument.update(llamapolicy)
        return argument

    @staticmethod
    def lm_head() -> List:
        return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]

from transformers import LlamaForSequenceClassification


class LlamaForSequenceClassificationPolicy(LlamaPolicy):

    @staticmethod
    def argument(config, world_size):
        llamapolicy = LlamaPolicy.argument_policy(config, world_size)
        argument = {
            LlamaForSequenceClassification:
                Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score])
        }
        argument.update(llamapolicy)
        return argument

    @staticmethod
    def score() -> List:
        return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
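
The policies above are plain data tables. Below is a minimal sketch of how they might be queried, assuming the Argument, Col_Layer and Row_Layer containers from .basepolicy expose the fields they are constructed with here (attr_dict, param_funcs, suffix, replace_layer); the driver code is illustrative only and is not shardformer's sharding entry point.

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from colossalai.shardformer.policies.llama import LlamaPolicy

config = LlamaConfig(hidden_size=4096, num_attention_heads=32)
world_size = 4    # illustrative tensor-parallel degree

arguments = LlamaPolicy.argument_policy(config, world_size)
decoder_argument = arguments[LlamaDecoderLayer]

# Per-rank attribute overrides computed by the policy.
print(decoder_argument.attr_dict)
# expected: {'self_attn.hidden_size': 1024, 'self_attn.num_heads': 8}

# Layer descriptions produced by the registered param_funcs.
for param_func in decoder_argument.param_funcs:
    for layer in param_func():
        print(type(layer).__name__, layer.suffix, layer.replace_layer.__name__)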