[shardformer] refactored the shardformer layer structure (#4053)

This commit is contained in:
Frank Lee
2023-06-21 14:30:06 +08:00
parent 58df720570
commit f22ddacef0
24 changed files with 196 additions and 471 deletions

View File

@@ -1,12 +1,7 @@
from typing import Type, Union
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.layer.dropout import Dropout1D
from ..utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription