diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 17c063c8d..8051433e8 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -5,6 +5,8 @@ import torch.nn as nn from .basepolicy import Policy +__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"] + @dataclass class PolicyLocation: diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 8835e38cb..2b9726069 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -8,6 +8,8 @@ import torch.nn as nn from ..shard.shard_config import ShardConfig +__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"] + class ParallelModule(): diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 545669f1f..cec7f0eb2 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,18 +1,16 @@ import torch.nn as nn -from transformers.models.bert.modeling_bert import ( - BertEmbeddings, - BertForMultipleChoice, - BertForSequenceClassification, - BertForTokenClassification, - BertLayer, - BertLMPredictionHead, -) import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +__all__ = [ + 'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', + 'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy', + 'BertForMultipleChoicePolicy' +] + class BertPolicy(Policy): @@ -33,6 +31,8 @@ class BertPolicy(Policy): return self.model def module_policy(self): + from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer + base_policy = { 
BertLayer: ModulePolicyDescription( @@ -123,7 +123,7 @@ class BertPolicy(Policy): def new_model_class(self): # do nothing - return self.model + return None def postprocess(self): return self.model @@ -143,6 +143,8 @@ class BertForPretrainingPolicy(BertPolicy): super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + module_policy = super().module_policy() addon_module = { BertLMPredictionHead: @@ -184,6 +186,8 @@ class BertLMHeadModelPolicy(BertPolicy): super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + module_policy = super().module_policy() addon_module = { BertLMPredictionHead: @@ -221,6 +225,8 @@ class BertForMaskedLMPolicy(BertPolicy): super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + module_policy = super().module_policy() addon_module = { BertLMPredictionHead: @@ -261,6 +267,8 @@ class BertForSequenceClassificationPolicy(BertPolicy): super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertForSequenceClassification + module_policy = super().module_policy() addon_module = { BertForSequenceClassification: @@ -284,6 +292,8 @@ class BertForTokenClassificationPolicy(BertPolicy): super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertForTokenClassification + module_policy = super().module_policy() addon_module = { BertForTokenClassification: @@ -314,6 +324,8 @@ class BertForMultipleChoicePolicy(BertPolicy): super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertForMultipleChoice + module_policy = super().module_policy() addon_module = { BertForMultipleChoice: diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 3d6d94b8e..c6108f5c0 100644 --- a/colossalai/shardformer/policies/gpt2.py 
+++ b/colossalai/shardformer/policies/gpt2.py @@ -1,11 +1,15 @@ import torch.nn as nn -from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +__all__ = [ + 'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy', + 'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy' +] + class GPT2Policy(Policy): @@ -25,7 +29,9 @@ class GPT2Policy(Policy): return self.model def module_policy(self): - base_policy = { + from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model + + return { GPT2Model: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -125,6 +131,8 @@ class GPT2LMHeadModelPolicy(GPT2Policy): super().__init__() def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + module_policy = super().module_policy() addon_module = { GPT2LMHeadModel: @@ -156,6 +164,8 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): super().__init__() def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel + module_policy = super().module_policy() addon_module = { GPT2DoubleHeadsModel: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b36180ce3..2fd2bc223 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,13 +1,13 @@ from typing import Dict, Union import torch.nn as nn -from transformers import LlamaForCausalLM, LlamaForSequenceClassification -from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .basepolicy import ModulePolicyDescription, Policy, 
SubModuleReplacementDescription +__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] + class LlamaPolicy(Policy): @@ -26,7 +26,9 @@ class LlamaPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - base_policy = { + from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel + + return { LlamaDecoderLayer: ModulePolicyDescription( attribute_replacement={ @@ -109,6 +111,8 @@ class LlamaPolicy(Policy): class LlamaForCausalLMPolicy(LlamaPolicy): def module_policy(self): + from transformers import LlamaForCausalLM + policy = super().module_policy() # add a new item for casual lm new_item = { @@ -128,6 +132,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy): class LlamaForSequenceClassificationPolicy(LlamaPolicy): def module_policy(self): + from transformers import LlamaForSequenceClassification + policy = super().module_policy() # add a new item for sequence classification diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index ce3873954..ec1bae208 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,15 +1,12 @@ -from transformers.models.opt.modeling_opt import ( - OPTAttention, - OPTDecoder, - OPTDecoderLayer, - OPTForCausalLM, - OPTForSequenceClassification, -) - from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +__all__ = [ + 'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy', + 'OPTForQuestionAnsweringPolicy' +] + class OPTPolicy(Policy): @@ -29,6 +26,8 @@ class OPTPolicy(Policy): return self.model def module_policy(self): + from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer + base_policy = { OPTDecoder: 
ModulePolicyDescription(attribute_replacement={}, @@ -111,6 +110,8 @@ class OPTModelPolicy(OPTPolicy): class OPTForCausalLMPolicy(OPTPolicy): def module_policy(self): + from transformers.models.opt.modeling_opt import OPTForCausalLM + policy = super().module_policy() new_item = { OPTForCausalLM: diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index d35f688a0..845bfe727 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,15 +1,4 @@ -from transformers import T5ForConditionalGeneration -from transformers.models.t5.modeling_t5 import ( - T5Attention, - T5DenseActDense, - T5DenseGatedActDense, - T5LayerCrossAttention, - T5LayerFF, - T5LayerSelfAttention, - T5Stack, -) - -from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -34,7 +23,17 @@ class T5ModelPolicy(Policy): return self.model def module_policy(self): - base_policy = { + from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5DenseActDense, + T5DenseGatedActDense, + T5LayerCrossAttention, + T5LayerFF, + T5LayerSelfAttention, + T5Stack, + ) + + return { T5Stack: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -165,6 +164,8 @@ class T5ModelPolicy(Policy): class T5ForConditionalGenerationPolicy(T5ModelPolicy): def module_policy(self): + from transformers import T5ForConditionalGeneration + policy = super().module_policy() new_item = { diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 5d8a235db..6a404c2fa 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,12 +1,13 @@ from typing import Dict, Union import torch.nn as nn -from 
transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +__all__ = ['ViTPolicy'] + class ViTPolicy(Policy): @@ -25,7 +26,9 @@ class ViTPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - base_policy = { + from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer + + return { ViTEmbeddings: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index e83191210..2116d2e62 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -19,6 +19,7 @@ class ShardConfig: """ tensor_parallel_process_group: int = None enable_fused_normalization: bool = False + enable_all_optimization: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int @@ -27,6 +28,21 @@ class ShardConfig: # inference_only: bool = True # gather_output: bool = True + @property + def tensor_parallel_size(self): + return self._tensor_parallel_size + def __post_init__(self): # get the parallel size - self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + + # turn on all optimizations if enable_all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() + + def _turn_on_all_optimization(self): + """ + Turn on every optimization flag of this config; called from __post_init__ when enable_all_optimization is True. + """ + # set each optimization flag declared on this config here (the field is enable_fused_normalization) + self.enable_fused_normalization = True