[shardformer] import huggingface implicitly (#4101)

This commit is contained in:
Frank Lee
2023-06-30 10:56:29 +08:00
parent 6a88bae4ec
commit 44a190e6ac
9 changed files with 91 additions and 38 deletions

View File

@@ -1,11 +1,15 @@
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
# Names this module exposes through a star-import (`from <module> import *`).
__all__ = [
    'GPT2Policy',
    'GPT2ModelPolicy',
    'GPT2LMHeadModelPolicy',
    'GPT2DoubleHeadsModelPolicy',
    'GPT2ForTokenClassificationPolicy',
    'GPT2ForSequenceClassificationPolicy',
]
class GPT2Policy(Policy):
@@ -25,7 +29,9 @@ class GPT2Policy(Policy):
return self.model
def module_policy(self):
base_policy = {
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
return {
GPT2Model:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@@ -125,6 +131,8 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
module_policy = super().module_policy()
addon_module = {
GPT2LMHeadModel:
@@ -156,6 +164,8 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
module_policy = super().module_policy()
addon_module = {
GPT2DoubleHeadsModel: