[shardformer] fix bert and gpt downstream with new api (#4024)
* fix bert downstream with new api
* remove comment line
@@ -7,7 +7,6 @@ from transformers import (
     AutoTokenizer,
     BertConfig,
     BertForMaskedLM,
     BertForMultipleChoice,
     BertForNextSentencePrediction,
     BertForPreTraining,
     BertForSequenceClassification,
@@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
     org_model.to('cuda')
     # TODO: no need to transfer to cuda
     org_model_forshard.to('cuda')
-    shard_config = ShardConfig(tensor_parallel_size=2,
-                               data_parallel_size=1,
-                               pipeline_parallel_size=1,
-                               tensor_parallel_mode='1d',
-                               inference_only=True,
-                               gather_output=True)
+    shard_config = ShardConfig(
+        tensor_parallel_size=2,
+        tensor_parallel_mode='1d',
+    )
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
     sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
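
For context, a minimal sketch of how the simplified API in this diff is driven end to end, assuming a two-GPU tensor-parallel run and the colossalai.shardformer module as of this commit; the BertForMaskedLM choice and the launch command are illustrative, not part of the change:

    # Sketch only: import paths and the ShardFormer calls follow this diff;
    # the model choice (BertForMaskedLM) and 2-GPU setup are assumptions.
    from transformers import BertConfig, BertForMaskedLM

    from colossalai.shardformer import ShardConfig, ShardFormer

    # Build the original (unsharded) model on a single device first.
    org_model = BertForMaskedLM(BertConfig())
    org_model.to('cuda')

    # The new API keeps only the tensor-parallel knobs; the old
    # data_parallel_size / pipeline_parallel_size / inference_only /
    # gather_output arguments are gone.
    shard_config = ShardConfig(
        tensor_parallel_size=2,
        tensor_parallel_mode='1d',
    )

    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()  # set up the tensor-parallel process group
    sharded_model = shard_former.shard_model(org_model).to('cuda')

Presumably launched with something like torchrun --nproc_per_node=2 so that the world size matches tensor_parallel_size=2.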