mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-27 04:33:04 +00:00
[shardformer] import huggingface implicitly (#4101)
@@ -1,11 +1,15 @@
 import torch.nn as nn
 
-from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model
-
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
     'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
     'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy'
 ]
 
 
 class GPT2Policy(Policy):
@@ -25,7 +29,9 @@ class GPT2Policy(Policy):
         return self.model
 
     def module_policy(self):
-        base_policy = {
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+
+        return {
             GPT2Model:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],
@@ -125,6 +131,8 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
         module_policy = super().module_policy()
         addon_module = {
             GPT2LMHeadModel:
@@ -156,6 +164,8 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
+
         module_policy = super().module_policy()
         addon_module = {
             GPT2DoubleHeadsModel:
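The pattern behind this commit: each `module_policy()` now imports its `transformers` classes inside the method body instead of at module scope, so `transformers` only has to be installed when a GPT-2 policy is actually used. A minimal sketch of that deferred-import pattern, with a hypothetical `ExamplePolicy` standing in for the real policy classes:

# Minimal sketch of the deferred-import pattern; ExamplePolicy is a
# hypothetical stand-in, not the actual ColossalAI policy class.
class ExamplePolicy:

    def module_policy(self):
        # Importing transformers here, rather than at module scope, means
        # the package is only required when this method is actually called,
        # so the policies package itself imports cleanly without it.
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
        return {GPT2Model: "policy description goes here"}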