mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 20:46:00 +00:00
[shardformer] import huggingface implicitly (#4101)
This commit is contained in:
@@ -1,15 +1,12 @@
|
||||
from transformers.models.opt.modeling_opt import (
|
||||
OPTAttention,
|
||||
OPTDecoder,
|
||||
OPTDecoderLayer,
|
||||
OPTForCausalLM,
|
||||
OPTForSequenceClassification,
|
||||
)
|
||||
|
||||
from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
|
||||
'OPTForQuestionAnsweringPolicy'
|
||||
]
|
||||
|
||||
|
||||
class OPTPolicy(Policy):
|
||||
|
||||
@@ -29,6 +26,8 @@ class OPTPolicy(Policy):
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
|
||||
|
||||
base_policy = {
|
||||
OPTDecoder:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
@@ -111,6 +110,8 @@ class OPTModelPolicy(OPTPolicy):
|
||||
class OPTForCausalLMPolicy(OPTPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
new_item = {
|
||||
OPTForCausalLM:
|
||||
|
Reference in New Issue
Block a user