Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 13:00:52 +00:00
integrate with dist layer (#4011)
@@ -17,7 +17,7 @@ from transformers import (
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import rerun_if_address_is_in_use, spawn

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -30,16 +30,21 @@ def build_model(rank, world_size, model):
     config.hidden_dropout_prob = 0
     config.attention_probs_dropout_prob = 0

-    org_model = model(config=config)
+    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
     org_model_forshard = copy.deepcopy(org_model)

-    org_model = org_model.to('cuda')
-    shardconfig = ShardConfig(
-        rank=rank,
-        world_size=world_size,
-        gather_output=True,
-    )
-    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
+    org_model.to('cuda')
+    # TODO: no need to transfer to cuda
+    org_model_forshard.to('cuda')
+    shard_config = ShardConfig(tensor_parallel_size=2,
+                               data_parallel_size=1,
+                               pipeline_parallel_size=1,
+                               tensor_parallel_mode='1d',
+                               inference_only=True,
+                               gather_output=True)
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')

     return org_model, sharded_model
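
The hunk above only rewrites build_model, so for context here is a minimal sketch of the per-rank driver such a test typically needs. Only the ShardConfig/ShardFormer flow and the spawn / rerun_if_address_is_in_use helpers come from this diff; the exact colossalai.launch arguments, the BertForMaskedLM import, and the output comparison are illustrative assumptions, not part of the commit.

import torch
from transformers import BertForMaskedLM

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_fn(rank, world_size, port):
    # One process per rank; bring up the distributed backend first
    # (assumed localhost/nccl setup, as is common in ColossalAI tests).
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    org_model, sharded_model = build_model(rank, world_size, BertForMaskedLM)
    # Sanity check: with gather_output=True the sharded forward pass
    # should reproduce the original model's logits.
    input_ids = torch.randint(0, 1000, (2, 16), device='cuda')
    assert torch.allclose(org_model(input_ids).logits,
                          sharded_model(input_ids).logits,
                          atol=1e-5)


@rerun_if_address_is_in_use()
def test_bert_shardformer():
    # spawn is assumed to fork `nprocs` workers and pass (rank, world_size, port).
    spawn(run_fn, nprocs=2)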