Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] import huggingface implicitly (#4101)
@@ -1,15 +1,4 @@
-from transformers import T5ForConditionalGeneration
-from transformers.models.t5.modeling_t5 import (
-    T5Attention,
-    T5DenseActDense,
-    T5DenseGatedActDense,
-    T5LayerCrossAttention,
-    T5LayerFF,
-    T5LayerSelfAttention,
-    T5Stack,
-)
-
-from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
@@ -34,7 +23,17 @@ class T5ModelPolicy(Policy):
         return self.model
 
     def module_policy(self):
-        base_policy = {
+        from transformers.models.t5.modeling_t5 import (
+            T5Attention,
+            T5DenseActDense,
+            T5DenseGatedActDense,
+            T5LayerCrossAttention,
+            T5LayerFF,
+            T5LayerSelfAttention,
+            T5Stack,
+        )
+
+        return {
             T5Stack:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],
@@ -165,6 +164,8 @@ class T5ModelPolicy(Policy):
 class T5ForConditionalGenerationPolicy(T5ModelPolicy):
 
     def module_policy(self):
+        from transformers import T5ForConditionalGeneration
+
         policy = super().module_policy()
 
         new_item = {
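The change above follows a simple deferred-import pattern: the transformers imports move from module level into the module_policy() methods that actually need them, so the Hugging Face dependency is only pulled in when a policy is built. The class below is a minimal, self-contained sketch of that pattern; LazyT5Policy is an illustrative name, not part of ColossalAI.

# Minimal sketch of the deferred-import pattern applied in this commit.
# LazyT5Policy is a hypothetical stand-in, not the real ColossalAI policy class.

class LazyT5Policy:

    def module_policy(self):
        # transformers is imported here, inside the method, rather than at
        # module level, so merely importing this file does not load the
        # Hugging Face package.
        from transformers.models.t5.modeling_t5 import T5Stack

        # The real policy maps T5Stack (and other T5 submodules) to
        # ModulePolicyDescription objects; a placeholder value is used here.
        return {T5Stack: None}

Defining or importing such a policy class therefore succeeds even if transformers is not installed; any import error only surfaces once module_policy() is actually called.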