Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-05-04 06:28:05 +00:00
[shardformer] import huggingface implicitly (#4101)
parent 6a88bae4ec
commit 44a190e6ac
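The diff below applies one pattern across the shardformer policy files: module-level `transformers` imports are replaced with imports inside `module_policy()` (plus explicit `__all__` lists), so HuggingFace model classes are only touched when a policy is actually built. A minimal sketch of that deferred-import pattern, assuming `transformers` is installed; `ToyPolicy` is illustrative only and not a class from this repository:

# Illustrative only: ToyPolicy is not a ColossalAI class; it just shows the
# method-local import style this commit switches to.
class ToyPolicy:

    def module_policy(self):
        # transformers is imported here, not at module scope, so merely
        # importing this file neither requires nor pays for transformers.
        from transformers.models.bert.modeling_bert import BertLayer

        return {BertLayer: "submodule replacement description goes here"}


policy = ToyPolicy()
print(policy.module_policy())    # transformers is loaded only at this point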
@@ -5,6 +5,8 @@ import torch.nn as nn
 
 from .basepolicy import Policy
 
+__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
+
 
 @dataclass
 class PolicyLocation:
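This hunk only exports the auto-policy helpers, but those helpers are what make the implicit imports work: `PolicyLocation` records where a policy lives as strings, and `import_policy` resolves it on demand. A hedged sketch of that registry idea follows; the field names, table entry, and package path are assumptions for illustration, not a copy of ColossalAI's exact implementation.

# Hedged sketch of a PolicyLocation-style lazy registry. Field names and the
# package path below are assumptions for illustration.
import importlib
from dataclasses import dataclass


@dataclass
class PolicyLocation:
    file_name: str     # policy module, e.g. "bert"
    class_name: str    # policy class, e.g. "BertModelPolicy"


# The table is keyed by fully qualified transformers class names (strings),
# so building it never imports transformers itself.
_POLICY_TABLE = {
    "transformers.models.bert.modeling_bert.BertModel":
        PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
}


def import_policy(location: PolicyLocation,
                  package: str = "colossalai.shardformer.policies"):
    # The policy module (and, through it, transformers) is imported only when
    # a policy is actually requested.
    module = importlib.import_module(f".{location.file_name}", package=package)
    return getattr(module, location.class_name)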
@@ -8,6 +8,8 @@ import torch.nn as nn
 
 from ..shard.shard_config import ShardConfig
 
+__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
+
 
 class ParallelModule():
 
@@ -1,18 +1,16 @@
 import torch.nn as nn
-from transformers.models.bert.modeling_bert import (
-    BertEmbeddings,
-    BertForMultipleChoice,
-    BertForSequenceClassification,
-    BertForTokenClassification,
-    BertLayer,
-    BertLMPredictionHead,
-)
 
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+__all__ = [
+    'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
+    'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
+    'BertForMultipleChoicePolicy'
+]
+
 
 class BertPolicy(Policy):
 

@@ -33,6 +31,8 @@ class BertPolicy(Policy):
         return self.model
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
+
         base_policy = {
             BertLayer:
                 ModulePolicyDescription(

@@ -123,7 +123,7 @@ class BertPolicy(Policy):
 
     def new_model_class(self):
         # do nothing
-        return self.model
+        return None
 
     def postprocess(self):
         return self.model

@@ -143,6 +143,8 @@ class BertForPretrainingPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertLMPredictionHead
+
         module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:

@@ -184,6 +186,8 @@ class BertLMHeadModelPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertLMPredictionHead
+
         module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:

@@ -221,6 +225,8 @@ class BertForMaskedLMPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertLMPredictionHead
+
         module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:

@@ -261,6 +267,8 @@ class BertForSequenceClassificationPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertForSequenceClassification
+
         module_policy = super().module_policy()
         addon_module = {
             BertForSequenceClassification:

@@ -284,6 +292,8 @@ class BertForTokenClassificationPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertForTokenClassification
+
         module_policy = super().module_policy()
         addon_module = {
             BertForTokenClassification:

@@ -314,6 +324,8 @@ class BertForMultipleChoicePolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertForMultipleChoice
+
         module_policy = super().module_policy()
         addon_module = {
             BertForMultipleChoice:
@@ -1,11 +1,15 @@
 import torch.nn as nn
-from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model
 
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+__all__ = [
+    'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
+    'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy'
+]
+
 
 class GPT2Policy(Policy):

@@ -25,7 +29,9 @@ class GPT2Policy(Policy):
         return self.model
 
     def module_policy(self):
-        base_policy = {
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+
+        return {
             GPT2Model:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],

@@ -125,6 +131,8 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
         module_policy = super().module_policy()
         addon_module = {
             GPT2LMHeadModel:

@@ -156,6 +164,8 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
+
         module_policy = super().module_policy()
         addon_module = {
             GPT2DoubleHeadsModel:
@@ -1,13 +1,13 @@
 from typing import Dict, Union
 
 import torch.nn as nn
-from transformers import LlamaForCausalLM, LlamaForSequenceClassification
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
+
 
 class LlamaPolicy(Policy):
 

@@ -26,7 +26,9 @@ class LlamaPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        base_policy = {
+        from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+
+        return {
             LlamaDecoderLayer:
                 ModulePolicyDescription(
                     attribute_replacement={

@@ -109,6 +111,8 @@ class LlamaPolicy(Policy):
 class LlamaForCausalLMPolicy(LlamaPolicy):
 
     def module_policy(self):
+        from transformers import LlamaForCausalLM
+
         policy = super().module_policy()
         # add a new item for casual lm
         new_item = {

@@ -128,6 +132,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
 class LlamaForSequenceClassificationPolicy(LlamaPolicy):
 
     def module_policy(self):
+        from transformers import LlamaForSequenceClassification
+
         policy = super().module_policy()
 
         # add a new item for sequence classification
@@ -1,15 +1,12 @@
-from transformers.models.opt.modeling_opt import (
-    OPTAttention,
-    OPTDecoder,
-    OPTDecoderLayer,
-    OPTForCausalLM,
-    OPTForSequenceClassification,
-)
-
 from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+__all__ = [
+    'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
+    'OPTForQuestionAnsweringPolicy'
+]
+
 
 class OPTPolicy(Policy):
 

@@ -29,6 +26,8 @@ class OPTPolicy(Policy):
         return self.model
 
     def module_policy(self):
+        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
+
         base_policy = {
             OPTDecoder:
                 ModulePolicyDescription(attribute_replacement={},

@@ -111,6 +110,8 @@ class OPTModelPolicy(OPTPolicy):
 class OPTForCausalLMPolicy(OPTPolicy):
 
     def module_policy(self):
+        from transformers.models.opt.modeling_opt import OPTForCausalLM
+
         policy = super().module_policy()
         new_item = {
             OPTForCausalLM:
@@ -1,15 +1,4 @@
-from transformers import T5ForConditionalGeneration
-from transformers.models.t5.modeling_t5 import (
-    T5Attention,
-    T5DenseActDense,
-    T5DenseGatedActDense,
-    T5LayerCrossAttention,
-    T5LayerFF,
-    T5LayerSelfAttention,
-    T5Stack,
-)
-
-from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 

@@ -34,7 +23,17 @@ class T5ModelPolicy(Policy):
         return self.model
 
     def module_policy(self):
-        base_policy = {
+        from transformers.models.t5.modeling_t5 import (
+            T5Attention,
+            T5DenseActDense,
+            T5DenseGatedActDense,
+            T5LayerCrossAttention,
+            T5LayerFF,
+            T5LayerSelfAttention,
+            T5Stack,
+        )
+
+        return {
             T5Stack:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],

@@ -165,6 +164,8 @@ class T5ModelPolicy(Policy):
 class T5ForConditionalGenerationPolicy(T5ModelPolicy):
 
     def module_policy(self):
+        from transformers import T5ForConditionalGeneration
+
         policy = super().module_policy()
 
         new_item = {
@@ -1,12 +1,13 @@
 from typing import Dict, Union
 
 import torch.nn as nn
-from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel
 
 from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+__all__ = ['ViTPolicy']
+
 
 class ViTPolicy(Policy):
 

@@ -25,7 +26,9 @@ class ViTPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        base_policy = {
+        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
+
+        return {
             ViTEmbeddings:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],
@@ -19,6 +19,7 @@ class ShardConfig:
     """
     tensor_parallel_process_group: int = None
     enable_fused_normalization: bool = False
+    enable_all_optimization: bool = False
 
     # TODO: add support for tensor parallel
     # pipeline_parallel_size: int

@@ -27,6 +28,21 @@ class ShardConfig:
     # inference_only: bool = True
     # gather_output: bool = True
 
+    @property
+    def tensor_parallel_size(self):
+        return self._tensor_parallel_size
+
     def __post_init__(self):
         # get the parallel size
-        self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
+        self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
+
+        # turn on all optimization if all_optimization is set to True
+        if self.enable_all_optimization:
+            self._turn_on_all_optimization()
+
+    def _turn_on_all_optimization(self):
+        """
+        Turn on all optimization.
+        """
+        # you can add all the optimization flag here
+        self.fused_layernorm = True
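The ShardConfig hunks add an enable_all_optimization flag, turn tensor_parallel_size into a read-only property backed by _tensor_parallel_size, and wire a _turn_on_all_optimization() hook into __post_init__. A hedged usage sketch follows; the import path is assumed from the package layout, and __post_init__ needs an initialized torch.distributed default group to query the world size.

# Hedged usage sketch; import path and bootstrap are assumptions for
# illustration, not taken from this diff.
import torch.distributed as dist

from colossalai.shardformer import ShardConfig

if not dist.is_initialized():
    # single-process bootstrap purely for illustration
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0,
                            world_size=1)

config = ShardConfig(enable_all_optimization=True)

# tensor_parallel_size is now a property reading _tensor_parallel_size,
# which __post_init__ derives from tensor_parallel_process_group.
print(config.tensor_parallel_size)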