[shardformer] import huggingface implicitly (#4101)

Frank Lee
2023-06-30 10:56:29 +08:00
parent 6a88bae4ec
commit 44a190e6ac
9 changed files with 91 additions and 38 deletions
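The change moves the HuggingFace transformers imports out of the module top level and into each policy's module_policy() method, so the model classes are imported only when the corresponding policy is actually used. Below is a minimal sketch of that deferred-import pattern, not the real shardformer API: OPTPolicySketch and the returned dict value are illustrative stand-ins.

class OPTPolicySketch:
    """Illustrative policy class that defers its HuggingFace import."""

    def module_policy(self):
        # Deferred import: transformers is resolved on the first call
        # to module_policy(), not when this module itself is imported.
        from transformers.models.opt.modeling_opt import OPTDecoder

        # Map the HuggingFace class to a replacement description
        # (a plain string here; the real code uses policy objects).
        return {OPTDecoder: 'module replacement description'}

With the imports deferred this way, importing the policy module no longer requires every supported HuggingFace architecture to be importable up front; only the models a user actually shards get loaded.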


@@ -1,15 +1,12 @@
-from transformers.models.opt.modeling_opt import (
-    OPTAttention,
-    OPTDecoder,
-    OPTDecoderLayer,
-    OPTForCausalLM,
-    OPTForSequenceClassification,
-)
 from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 __all__ = [
     'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
     'OPTForQuestionAnsweringPolicy'
 ]
 class OPTPolicy(Policy):
@@ -29,6 +26,8 @@ class OPTPolicy(Policy):
         return self.model

     def module_policy(self):
+        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
+
         base_policy = {
             OPTDecoder:
                 ModulePolicyDescription(attribute_replacement={},
@@ -111,6 +110,8 @@ class OPTModelPolicy(OPTPolicy):
 class OPTForCausalLMPolicy(OPTPolicy):

     def module_policy(self):
+        from transformers.models.opt.modeling_opt import OPTForCausalLM
+
         policy = super().module_policy()
         new_item = {
             OPTForCausalLM: