mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[shardformer] Add layernorm (#4072)
* add layernorm to bert * add layernorm test * add layernorm test with load state dict * add use_mixedfusedLN in shard config * refactor policy to support fused_layernorm
This commit is contained in:
@@ -8,7 +8,7 @@ def build_model(world_size, model_fn):
|
||||
org_model = model_fn().cuda()
|
||||
|
||||
# shard model
|
||||
shard_config = ShardConfig(tensor_parallel_size=world_size)
|
||||
shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
|
||||
model_copy = copy.deepcopy(org_model)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
shard_former.init_distributed()
|
||||
|
Reference in New Issue
Block a user