Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] made tensor parallelism configurable (#4144)
* [shardformer] made tensor parallelism configurable
* polish code
@@ -3,12 +3,13 @@ import copy
 from colossalai.shardformer import ShardConfig, ShardFormer
 
 
-def build_model(model_fn):
+def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
     # create new model
     org_model = model_fn().cuda()
 
     # shard model
-    shard_config = ShardConfig(enable_fused_normalization=True)
+    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
+                               enable_tensor_parallelism=enable_tensor_parallelism)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model = shard_former.optimize(model_copy).cuda()
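For context, a minimal usage sketch (not part of the commit): it assumes build_model returns the original and sharded model pair, since the return statement sits outside the hunk shown above, and model_fn below is a hypothetical stand-in for the transformer constructors the real tests pass in. Running it needs a CUDA device, because the helper moves both models to the GPU.

import torch.nn as nn

def model_fn():
    # Hypothetical stand-in; the actual tests pass model constructors
    # (e.g. HuggingFace transformer models) as model_fn.
    return nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))

# Default keeps the previous behaviour, with tensor parallelism enabled.
org_model, sharded_model = build_model(model_fn)

# New in this commit: tensor parallelism can be switched off per test,
# so the model is optimized without splitting parameters across ranks.
org_model, sharded_model = build_model(model_fn, enable_tensor_parallelism=False)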