diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py
index 80ea7a252..baae95980 100644
--- a/colossalai/shardformer/policies/basepolicy.py
+++ b/colossalai/shardformer/policies/basepolicy.py
@@ -76,6 +76,7 @@ class Policy(ABC):
 
     def __init__(self) -> None:
         self.model = None
+        self.shard_config = None
 
     def set_model(self, model: nn.Module) -> None:
         r"""
@@ -86,14 +87,23 @@ class Policy(ABC):
         """
         self.model = model
 
+    def set_shard_config(self, shard_config: ShardConfig) -> None:
+        r"""
+        Set shard config as an attribute of the Policy object.
+
+        Args:
+            shard_config (:class:`ShardConfig`): The shard config to be used
+        """
+        self.shard_config = shard_config
+
     @abstractmethod
-    def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
+    def preprocess(self) -> nn.Module:
         r"""
         Perform some preprocessing of the model, like reshaping the embedding layer
         """
 
     @abstractmethod
-    def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         r"""
         Return the dict for the modify policy, the key is the original layer class and the value is the
         argument for the modify layer
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index fe74f83ca..06ee9b435 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -4,41 +4,40 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
 import colossalai.shardformer.layer.layers as col_nn
 from colossalai.shardformer.layer.dropout import Dropout1D
 
-from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
 class BertPolicy(Policy):
 
-    def preprocess(self, shard_config: ShardConfig = None):
+    def preprocess(self):
         # reshape the embedding layer
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         # TODO:
         vocab_size = self.model.config.vocab_size
-        world_size = shard_config.tensor_parallel_size
+        world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
         return self.model
 
-    def module_policy(self, shard_config: ShardConfig = None):
+    def module_policy(self):
         return {
             BertLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
                         # 1. shard hidden size
                         "attention.self.all_head_size":
-                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         "crossattention.self.all_head_size":
-                            self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         # 2. shard number of heads
                         "attention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                         "crossattention.self.num_attention_heads":
-                            self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                     },
                     param_replacement=[],
                     sub_module_replacement=[
@@ -100,13 +99,43 @@ class BertPolicy(Policy):
         return self.model
 
 
+# BertModel
+class BertModelPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForPreTraining
+class BertForPretrainingPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+
+# BertForMaskedLM
 class BertForMaskedLMPolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
 
-    def module_policy(self, shard_config: ShardConfig = None):
-        module_policy = super().module_policy(shard_config)
+    def module_policy(self):
+        module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
                 ModulePolicyDescription(attribute_replacement={},
@@ -124,16 +153,41 @@ class BertForMaskedLMPolicy(BertPolicy):
 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):
 
-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = BertPolicy.argument_policy(config, world_size)
-        argument = {
-            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
-                BertPolicy.unembedding,
-            ]),
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="decoder",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
         }
-        argument.update(base_argument)
-        return argument
+        module_policy.update(addon_module)
+        return module_policy
+
+
+# BertForNextSentencePrediction
+class BertForNextSentencePredictionPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForSequenceClassification
+class BertForSequenceClassificationPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForMultipleChoice
+class BertForMultipleChoicePolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 53999529d..670a5775d 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -18,10 +18,10 @@ class ShardConfig:
             will not calculate the loss and just return the output.
         gather_output (bool): Whether to gather the output of the model of the last layer
     """
-    data_parallel_size: int
     tensor_parallel_size: int
-
-    pipeline_parallel_size: int
+    # TODO: add support for data & pipeline parallel
+    # pipeline_parallel_size: int
+    # data_parallel_size: int
     tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
     inference_only: bool = True
     gather_output: bool = True
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 5c8584595..b90e79059 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -40,6 +40,7 @@ class ModelSharder(object):
         Shard the model according to the policy
         """
         self.policy.set_model(self.model)
+        self.policy.set_shard_config(self.shard_config)
         self.preprocess()
         self.replace_model_class()
         self.replace_module()
@@ -57,12 +58,12 @@ class ModelSharder(object):
         self.model_config = self.model.config
 
     def preprocess(self) -> None:
-        self.model = self.policy.preprocess(self.shard_config)
+        self.model = self.policy.preprocess()
 
     def postprocess(self) -> None:
         self.model = self.policy.postprocess()
 
-    def replace_model_class(self,) -> None:
+    def replace_model_class(self) -> None:
         r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model
@@ -83,14 +84,14 @@ class ModelSharder(object):
                 getattr(new_model_class, key),
             )
 
-    def replace_module(self,) -> None:
+    def replace_module(self) -> None:
         r"""
         Replace the module according to the policy, and replace the module one by one
 
         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        module_descriptions = self.policy.module_policy(self.shard_config)
+        module_descriptions = self.policy.module_policy()
         for module_description in module_descriptions.items():
             origin_layer_cls = module_description[0]
             attr_replacement = module_description[1].attribute_replacement
diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py
index 5313dfecb..954bdaa82 100644
--- a/colossalai/shardformer/shard/shardformer.py
+++ b/colossalai/shardformer/shard/shardformer.py
@@ -25,11 +25,7 @@ class ShardFormer:
         org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
         shard_config = ShardConfig(
             tensor_parallel_size=2,
-            data_parallel_size=1,
-            pipeline_parallel_size=1,
             tensor_parallel_mode='1d',
-            inference_only=True,
-            gather_output=True
         )
         shard_former = ShardFormer(shard_config=shard_config)
         shard_former.init_distributed()
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 05d033436..0dd0fdeee 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -7,7 +7,6 @@ from transformers import (
     AutoTokenizer,
     BertConfig,
     BertForMaskedLM,
-    BertForMultipleChoice,
     BertForNextSentencePrediction,
     BertForPreTraining,
     BertForSequenceClassification,
@@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
     org_model.to('cuda')
     # TODO: no need to transfer to cuda
     org_model_forshard.to('cuda')
-    shard_config = ShardConfig(tensor_parallel_size=2,
-                               data_parallel_size=1,
-                               pipeline_parallel_size=1,
-                               tensor_parallel_mode='1d',
-                               inference_only=True,
-                               gather_output=True)
+    shard_config = ShardConfig(
+        tensor_parallel_size=2,
+        tensor_parallel_mode='1d',
+    )
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
     sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
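
Usage sketch (not part of the diff above): a minimal illustration of the refactored Policy API this change introduces, where the model and ShardConfig are injected through set_model()/set_shard_config() instead of being threaded through preprocess()/module_policy(). The checkpoint name and parallel settings mirror the docstring example in shardformer.py; the import paths are assumptions based on the file layout shown in the diff and may differ in the installed package.

from transformers import BertForMaskedLM

from colossalai.shardformer.policies.bert import BertForMaskedLMPolicy
from colossalai.shardformer.shard.shard_config import ShardConfig

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig(tensor_parallel_size=2, tensor_parallel_mode='1d')

policy = BertForMaskedLMPolicy()
policy.set_model(model)                # done by ModelSharder.shard() in the diff
policy.set_shard_config(shard_config)  # new setter introduced by this diff

model = policy.preprocess()            # pads the vocab to a multiple of tensor_parallel_size
descriptions = policy.module_policy()  # dict keyed by BertLayer / BertLMPredictionHead

In normal use these calls are made internally by ModelSharder.shard(); invoking them directly as above is only for illustrating the new argument-free preprocess()/module_policy() signatures.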