[shardformer] adapted T5 and LLaMa test to use kit (#4049)

* [shardformer] adapted T5 and LLaMa test to use kit

* polish code
Frank Lee
2023-06-21 09:32:46 +08:00
parent 4021b9a8a2
commit 58df720570
24 changed files with 239 additions and 168 deletions


@@ -4,7 +4,7 @@ import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
-from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
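
The only functional change in this hunk is the import path: the 1D tensor-parallel layers are now taken from the public colossalai.shardformer.layer package instead of its internal layers submodule. For orientation, below is a minimal, hypothetical sketch of how these classes could be wired into a LLaMa module policy. The keyword fields (sub_module_replacement, suffix, target_module) and the q_proj/k_proj/v_proj/o_proj/embed_tokens suffixes are assumptions for illustration, not a copy of the actual policy file in this commit.

# Hypothetical sketch only: maps the re-exported 1D parallel layers onto LLaMa
# sub-modules. Field names below are assumed from the description classes
# imported in the diff above; check the real policy file for exact signatures.
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.basepolicy import (
    ModulePolicyDescription,
    SubModuleReplacementDescription,
)

# Column-parallel q/k/v projections feed a row-parallel output projection, so the
# attention block needs only one all-reduce; the embedding is sharded over the vocab.
llama_module_policy = {
    LlamaDecoderLayer: ModulePolicyDescription(
        attribute_replacement={},
        param_replacement=[],
        sub_module_replacement=[
            SubModuleReplacementDescription(suffix="self_attn.q_proj", target_module=Linear1D_Col),
            SubModuleReplacementDescription(suffix="self_attn.k_proj", target_module=Linear1D_Col),
            SubModuleReplacementDescription(suffix="self_attn.v_proj", target_module=Linear1D_Col),
            SubModuleReplacementDescription(suffix="self_attn.o_proj", target_module=Linear1D_Row),
        ],
    ),
    LlamaModel: ModulePolicyDescription(
        attribute_replacement={},
        param_replacement=[],
        sub_module_replacement=[
            SubModuleReplacementDescription(suffix="embed_tokens", target_module=VocabParallelEmbedding1D),
        ],
    ),
}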