mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 04:33:04 +00:00
[shardformer] import huggingface implicitly (#4101)
This commit is contained in:
@@ -1,18 +1,16 @@
|
||||
import torch.nn as nn
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertEmbeddings,
|
||||
BertForMultipleChoice,
|
||||
BertForSequenceClassification,
|
||||
BertForTokenClassification,
|
||||
BertLayer,
|
||||
BertLMPredictionHead,
|
||||
)
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
|
||||
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
|
||||
'BertForMultipleChoicePolicy'
|
||||
]
|
||||
|
||||
|
||||
class BertPolicy(Policy):
|
||||
|
||||
@@ -33,6 +31,8 @@ class BertPolicy(Policy):
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
|
||||
|
||||
base_policy = {
|
||||
BertLayer:
|
||||
ModulePolicyDescription(
|
||||
@@ -123,7 +123,7 @@ class BertPolicy(Policy):
|
||||
|
||||
def new_model_class(self):
|
||||
# do nothing
|
||||
return self.model
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
@@ -143,6 +143,8 @@ class BertForPretrainingPolicy(BertPolicy):
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertLMPredictionHead
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
@@ -184,6 +186,8 @@ class BertLMHeadModelPolicy(BertPolicy):
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertLMPredictionHead
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
@@ -221,6 +225,8 @@ class BertForMaskedLMPolicy(BertPolicy):
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertLMPredictionHead
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertLMPredictionHead:
|
||||
@@ -261,6 +267,8 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForSequenceClassification
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForSequenceClassification:
|
||||
@@ -284,6 +292,8 @@ class BertForTokenClassificationPolicy(BertPolicy):
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForTokenClassification
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForTokenClassification:
|
||||
@@ -314,6 +324,8 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForMultipleChoice
|
||||
|
||||
module_policy = super().module_policy()
|
||||
addon_module = {
|
||||
BertForMultipleChoice:
|
||||
|
Reference in New Issue
Block a user