[shardformer] import huggingface implicitly (#4101)

This commit is contained in:
Frank Lee
2023-06-30 10:56:29 +08:00
parent 6a88bae4ec
commit 44a190e6ac
9 changed files with 91 additions and 38 deletions

View File

@@ -1,18 +1,16 @@
import torch.nn as nn
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertForMultipleChoice,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMPredictionHead,
)
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
'BertForMultipleChoicePolicy'
]
class BertPolicy(Policy):
@@ -33,6 +31,8 @@ class BertPolicy(Policy):
return self.model
def module_policy(self):
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
base_policy = {
BertLayer:
ModulePolicyDescription(
@@ -123,7 +123,7 @@ class BertPolicy(Policy):
def new_model_class(self):
    """Return the replacement model class for this policy.

    Returns:
        None: the base ``BertPolicy`` performs no model-class substitution
        (the ``# do nothing`` comment marks this as an intentional no-op hook
        for subclasses to override).
    """
    # NOTE(review): the diff rendering showed both the old `return self.model`
    # and the new `return None`; the second return was unreachable dead code.
    # Per the hunk header (a one-line replacement) the intended body returns None.
    # do nothing
    return None
def postprocess(self):
    """Post-sharding hook; this policy needs no extra work, so the (possibly
    sharded) model held by the policy is handed back unchanged."""
    finished_model = self.model
    return finished_model
@@ -143,6 +143,8 @@ class BertForPretrainingPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
@@ -184,6 +186,8 @@ class BertLMHeadModelPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
@@ -221,6 +225,8 @@ class BertForMaskedLMPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
@@ -261,6 +267,8 @@ class BertForSequenceClassificationPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForSequenceClassification
module_policy = super().module_policy()
addon_module = {
BertForSequenceClassification:
@@ -284,6 +292,8 @@ class BertForTokenClassificationPolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForTokenClassification
module_policy = super().module_policy()
addon_module = {
BertForTokenClassification:
@@ -314,6 +324,8 @@ class BertForMultipleChoicePolicy(BertPolicy):
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForMultipleChoice
module_policy = super().module_policy()
addon_module = {
BertForMultipleChoice: