[shardformer] import huggingface implicitly (#4101)

Frank Lee
2023-06-30 10:56:29 +08:00
parent 6a88bae4ec
commit 44a190e6ac
9 changed files with 91 additions and 38 deletions
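
The change moves the module-level `transformers` imports into the method bodies that actually use them, so importing a policy module no longer pulls in `transformers` at import time. A minimal sketch of the deferred-import pattern applied throughout this commit (the `ExamplePolicy` class is illustrative, not part of the diff):

    class ExamplePolicy:

        def module_policy(self):
            # Deferred ("implicit") import: transformers is loaded only when
            # module_policy() is called, not when this module is imported.
            from transformers import T5ForConditionalGeneration
            return {T5ForConditionalGeneration: "..."}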

@@ -1,15 +1,4 @@
-from transformers import T5ForConditionalGeneration
-from transformers.models.t5.modeling_t5 import (
-    T5Attention,
-    T5DenseActDense,
-    T5DenseGatedActDense,
-    T5LayerCrossAttention,
-    T5LayerFF,
-    T5LayerSelfAttention,
-    T5Stack,
-)
-
-from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
@@ -34,7 +23,17 @@ class T5ModelPolicy(Policy):
         return self.model
 
     def module_policy(self):
-        base_policy = {
+        from transformers.models.t5.modeling_t5 import (
+            T5Attention,
+            T5DenseActDense,
+            T5DenseGatedActDense,
+            T5LayerCrossAttention,
+            T5LayerFF,
+            T5LayerSelfAttention,
+            T5Stack,
+        )
+
+        return {
             T5Stack:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],
@@ -165,6 +164,8 @@ class T5ModelPolicy(Policy):
 
 class T5ForConditionalGenerationPolicy(T5ModelPolicy):
 
     def module_policy(self):
+        from transformers import T5ForConditionalGeneration
+
         policy = super().module_policy()
         new_item = {
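
With the imports deferred, a fresh interpreter can load the policy definitions without `transformers` being imported as a side effect; the library is only loaded (and only required) when a policy is actually queried. A quick check of that behavior, assuming a fresh interpreter and with `LazyPolicy` as a hypothetical stand-in for the refactored policies:

    import sys

    class LazyPolicy:

        def module_policy(self):
            # transformers is imported on first call, so a missing install
            # raises ImportError here rather than at module import time.
            from transformers import T5ForConditionalGeneration
            return {T5ForConditionalGeneration: "..."}

    assert "transformers" not in sys.modules   # defining the class imported nothing
    policy = LazyPolicy().module_policy()      # transformers loads here
    assert "transformers" in sys.modules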