[shardformer] adapted llama to the new API (#4036)

This commit is contained in:
Frank Lee
2023-06-19 13:53:17 +08:00
parent 74d176c8d8
commit c1d5453e9f
9 changed files with 238 additions and 201 deletions

View File

@@ -1,64 +1,76 @@
import importlib
from dataclasses import dataclass
import torch.nn as nn
from .basepolicy import Policy
def build_policies():
r"""
Build the policies for the model
Return:
The dict for the policies
@dataclass
class PolicyLocation:
"""
auto_policy_dict = {}
PolicyLocation describes the location of a policy class.
from transformers import BertModel
Args:
file_name (str): The file name of the policy under colossalai.shardformer.policies
class_name (str): The class name of the policy class
"""
file_name: str
class_name: str
from .bert import BertModelPolicy
auto_policy_dict[BertModel] = BertModelPolicy
from transformers import BertForPreTraining
# we don't want to import all policies here
# as each policy file imports its own model zoo library
# we will allow the user to only import the policy file needed
_POLICY_LIST = {
# BERT
"transformers.models.bert.modeling_bert.BertModel":
PolicyLocation(file_name="bert", class_name="BertPolicy"),
"transformers.models.bert.modeling_bert.BertForPreTraining":
PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"),
"transformers.models.bert.modeling_bert.BertForMaskedLM":
PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
"transformers.models.bert.modeling_bert.BertLMHeadModel":
PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
"transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
"transformers.models.bert.modeling_bert.BertForSequenceClassification":
PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"),
"transformers.models.bert.modeling_bert.BertForMultipleChoice":
PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
from .bert import BertForPretrainingPolicy
auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy
# LLaMA
"transformers.models.llama.modeling_llama.LlamaModel":
PolicyLocation(file_name="llama", class_name="LlamaPolicy"),
"transformers.models.llama.modeling_llama.LlamaForCausalLM":
PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"),
"transformers.models.llama.modeling_llama.LlamaForSequenceClassification":
PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),
from transformers import BertLMHeadModel
# T5
from .bert import BertLMHeadModelPolicy
auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy
# GPT2
}
from transformers import BertForMaskedLM
from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
def import_policy(policy_location: PolicyLocation) -> Policy:
"""
Dynamically import a Policy class based on the policy location.
"""
module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
module = importlib.import_module(module_name)
return getattr(module, policy_location.class_name)
from transformers import BertForNextSentencePrediction
from .bert import BertForNextSentencePredictionPolicy
auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy
from transformers import BertForSequenceClassification
from .bert import BertForSequenceClassificationPolicy
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
from transformers.models.llama.modeling_llama import LlamaModel
# from .llama import LlamaPolicy
# auto_policy_dict[LlamaModel] = LlamaPolicy
# from transformers import LlamaForSequenceClassification
# from .llama import LlamaForSequenceClassificationPolicy
# auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
# from transformers import LlamaForCausalLM
# from .llama import LlamaForCausalLMPolicy
# auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
# from transformers import GPT2Model
# from .gpt2 import GPT2Policy
# auto_policy_dict[GPT2Model] = GPT2Policy
# from transformers import GPT2LMHeadModel
# from .gpt2 import GPT2LMHeadModelPolicy
# auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
return auto_policy_dict
def _fullname(obj):
"""
Return the full name of an object, including the module name.
"""
klass = obj.__class__
module = klass.__module__
if module == 'builtins':
return klass.__qualname__ # avoid outputs like 'builtins.str'
return module + '.' + klass.__qualname__
def get_autopolicy(model: nn.Module) -> Policy:
@@ -71,16 +83,14 @@ def get_autopolicy(model: nn.Module) -> Policy:
Return:
:class:`Policy`: The auto policy for the model
"""
auto_policy_dict = build_policies()
policy = auto_policy_dict.get(model.__class__, None)
if policy is None:
full_name = _fullname(model)
policy_location = _POLICY_LIST.get(full_name, None)
if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location)
return policy()
return policy()
# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
# print(policy)

View File

@@ -75,6 +75,7 @@ class Policy(ABC):
"""
def __init__(self) -> None:
self.shard_config = None
self.model = None
self.shard_config = None
@@ -101,6 +102,7 @@ class Policy(ABC):
r"""
Perform some preprocessing of the model, like reshaping the embedding layer
"""
pass
@abstractmethod
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
@@ -135,6 +137,7 @@ class Policy(ABC):
...
}
"""
pass
@abstractmethod
def new_model_class(self) -> Union[Type[nn.Module], None]:
@@ -149,6 +152,7 @@ class Policy(ABC):
return BertModel_
```
"""
pass
@abstractmethod
def postprocess(self) -> nn.Module:
@@ -156,3 +160,4 @@ class Policy(ABC):
Perform some postprocessing of the model, like binding the weight of embedding layer with
the classifier layer
"""
pass

View File

@@ -1,122 +1,121 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
from typing import Dict, Union
import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
import colossalai.shardformer.layer.layers as col_nn
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .basepolicy import Argument, Col_Layer, Policy, Row_Layer
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class LlamaPolicy(Policy):
@staticmethod
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
def preprocess(self):
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
return {
LlamaDecoderLayer:
Argument(attr_dict={
"self_attn.hidden_size": config.hidden_size // world_size,
"self_attn.num_heads": config.num_attention_heads // world_size,
},
param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]),
ModulePolicyDescription(
attribute_replacement={
"self_attn.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
)
],
),
LlamaModel:
Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings])
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
])
}
@staticmethod
def attn_layer() -> List:
return [
Col_Layer(
suffix="self_attn.q_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="self_attn.k_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="self_attn.v_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="self_attn.o_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
)
]
def new_model_class(self):
return None
@staticmethod
def mlp_layer() -> List:
return [
Col_Layer(
suffix="mlp.gate_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
),
Col_Layer(
suffix="mlp.up_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
gather_output=True,
),
Col_Layer(
suffix="mlp.down_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
),
]
@staticmethod
def embeddings() -> List:
return [Col_Layer(
suffix="embed_tokens",
weight="weight",
replace_layer=col_nn.VocabParallelEmbedding1D,
)]
from transformers import LlamaForCausalLM
def postprocess(self):
return self.model
class LlamaForCausalLMPolicy(LlamaPolicy):
@staticmethod
def argument(config, world_size):
llamapolicy = LlamaPolicy.argument_policy(config, world_size)
argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])}
argument.update(llamapolicy)
@staticmethod
def lm_head() -> List:
return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
from transformers import LlamaForSequenceClassification
def module_policy(self):
policy = super().module_policy()
# add a new item for casual lm
new_item = {
LlamaForCausalLM:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
@staticmethod
def argument(config, world_size):
llamapolicy = LlamaPolicy.argument_policy(config, world_size)
argument = {
LlamaForSequenceClassification:
Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score])
}
argument.update(llamapolicy)
def module_policy(self):
policy = super().module_policy()
@staticmethod
def score() -> List:
return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy

View File

@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import List, Literal
from colossalai.cluster.dist_coordinator import DistCoordinator
__all__ = ['ShardConfig']
@@ -19,9 +20,19 @@ class ShardConfig:
gather_output (bool): Whether to gather the output of the model of the last layer
"""
tensor_parallel_size: int
# TODO: add support for tensor parallel
# pipeline_parallel_size: int
# data_parallel_size: int
tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
inference_only: bool = True
gather_output: bool = True
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
# inference_only: bool = True
# gather_output: bool = True
def __post_init__(self):
coordinator = DistCoordinator()
# ensure the parallel size can match the world size
world_size = coordinator.world_size
self.data_parallel_size = world_size // self.tensor_parallel_size
assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"

View File

@@ -1,8 +1,6 @@
from typing import Any, Callable, Dict, List
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from colossalai.cluster.process_group_manager import ProcessGroupManager
@@ -41,10 +39,10 @@ class ModelSharder(object):
"""
self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self.preprocess()
self.replace_model_class()
self.replace_module()
self.postprocess()
self._preprocess()
self._replace_model_class()
self._replace_module()
self._postprocess()
def reshape_embedding(self) -> None:
r"""
@@ -57,13 +55,13 @@ class ModelSharder(object):
self.model.resize_token_embeddings(new_vocab_size)
self.model_config = self.model.config
def preprocess(self) -> None:
def _preprocess(self) -> None:
self.model = self.policy.preprocess()
def postprocess(self) -> None:
def _postprocess(self) -> None:
self.model = self.policy.postprocess()
def replace_model_class(self) -> None:
def _replace_model_class(self,) -> None:
r"""
Replace the model to policy defined model
Mainly modify the forward and backward to fit distributed model
@@ -84,7 +82,7 @@ class ModelSharder(object):
getattr(new_model_class, key),
)
def replace_module(self) -> None:
def _replace_module(self,) -> None:
r"""
Replace the module according to the policy, and replace the module one by one

View File

@@ -47,10 +47,12 @@ class ShardFormer:
"""
Initialize the distributed process group according to the
"""
# create process group manager and 1d process group
# TODO: may need to support other parallel mode when the config has such as field
pg_manager = ProcessGroupManager()
if (self.shard_config.tensor_parallel_mode == '1d'):
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
self.pg_manager = pg_manager
return pg_manager
def shard_model(self, model: nn.Module, policy: Policy = None):