[shardformer] fix bert and gpt downstream with new api (#4024)

* fix bert downstream with new api

* remove comment line
This commit is contained in:
FoolPlayer
2023-06-19 10:47:16 +08:00
committed by Frank Lee
parent e253a07007
commit 74d176c8d8
6 changed files with 97 additions and 39 deletions

View File

@@ -7,7 +7,6 @@ from transformers import (
AutoTokenizer,
BertConfig,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForSequenceClassification,
@@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
org_model.to('cuda')
# TODO: no need to transfer to cuda
org_model_forshard.to('cuda')
shard_config = ShardConfig(tensor_parallel_size=2,
data_parallel_size=1,
pipeline_parallel_size=1,
tensor_parallel_mode='1d',
inference_only=True,
gather_output=True)
shard_config = ShardConfig(
tensor_parallel_size=2,
tensor_parallel_mode='1d',
)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')