[shardformer] integrate with data parallelism (#4103)

This commit is contained in:
Frank Lee
2023-06-30 09:58:08 +08:00
parent f3b6aaa6b7
commit 6a88bae4ec
11 changed files with 97 additions and 50 deletions

View File

@@ -3,17 +3,15 @@ import copy
from colossalai.shardformer import ShardConfig, ShardFormer
def build_model(model_fn):
    """Build an original model and a ShardFormer-sharded copy for comparison.

    Args:
        model_fn: zero-argument callable that returns a fresh model instance
            (e.g. a HuggingFace model constructor).

    Returns:
        Tuple of (org_model, sharded_model): the original model moved to CUDA,
        and a deep copy of it sharded by ShardFormer, also on CUDA. The deep
        copy ensures both models start from identical weights so their outputs
        can be compared in tests.
    """
    # create new model
    org_model = model_fn().cuda()
    # shard model; tensor_parallel_size is no longer passed — the shard
    # configuration now derives parallel sizes from the process group
    shard_config = ShardConfig(enable_fused_normalization=True)
    # deep-copy before sharding so org_model keeps its unsharded weights
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(model_copy).cuda()
    return org_model, sharded_model