integrate with dist layer (#4011)

This commit is contained in:
FoolPlayer
2023-06-16 11:23:30 +08:00
committed by Frank Lee
parent 015af592f8
commit dfca9678fa
3 changed files with 42 additions and 24 deletions

View File

@@ -17,7 +17,7 @@ from transformers import (
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -30,16 +30,21 @@ def build_model(rank, world_size, model):
config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0
org_model = model(config=config)
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
org_model_forshard = copy.deepcopy(org_model)
org_model = org_model.to('cuda')
shardconfig = ShardConfig(
rank=rank,
world_size=world_size,
gather_output=True,
)
sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
org_model.to('cuda')
# TODO: no need to transfer to cuda
org_model_forshard.to('cuda')
shard_config = ShardConfig(tensor_parallel_size=2,
data_parallel_size=1,
pipeline_parallel_size=1,
tensor_parallel_mode='1d',
inference_only=True,
gather_output=True)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
return org_model, sharded_model