[shardformer] made tensor parallelism configurable (#4144)

* [shardformer] made tensor parallelism configurable

* polish code
Frank Lee
2023-07-04 09:57:03 +08:00
parent 74257cb446
commit 1fb0d95df0
15 changed files with 819 additions and 673 deletions


@@ -3,12 +3,13 @@ import copy
 from colossalai.shardformer import ShardConfig, ShardFormer

-def build_model(model_fn):
+def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
     # create new model
     org_model = model_fn().cuda()

     # shard model
-    shard_config = ShardConfig(enable_fused_normalization=True)
+    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
+                               enable_tensor_parallelism=enable_tensor_parallelism)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model = shard_former.optimize(model_copy).cuda()
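
A minimal usage sketch of the new switch, for context. Only the ShardConfig flags, the ShardFormer constructor, and the optimize() call come from this diff; the `model` variable is a placeholder for any supported transformers module and is not part of the change.

from colossalai.shardformer import ShardConfig, ShardFormer

# Fused normalization and tensor parallelism can now be toggled independently.
# Here tensor parallelism is switched off, so sharding applies only the
# fused-normalization replacement.
shard_config = ShardConfig(enable_fused_normalization=True,
                           enable_tensor_parallelism=False)
shard_former = ShardFormer(shard_config=shard_config)
# sharded_model = shard_former.optimize(model).cuda()  # `model`: placeholder module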