[shardformer] Add layernorm (#4072)

* add layernorm to bert

* add layernorm test

* add layernorm test with load state dict

* add use_mixedfusedLN in shard config

* refactor policy to support fused_layernorm
Authored by FoolPlayer on 2023-06-23 18:00:22 +08:00, committed by Frank Lee
parent 70c58cfd4f
commit 92f6791095
7 changed files with 252 additions and 17 deletions
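The new flag is plumbed through ShardConfig, as the test diff below shows. A minimal sketch of the relevant config fields, assuming a dataclass-style config (only the two parameters visible in the diff are confirmed; the defaults are assumptions):

from dataclasses import dataclass

@dataclass
class ShardConfig:
    # sketch: the real ShardConfig carries more fields than shown here
    tensor_parallel_size: int = 1
    # opt into fused LayerNorm kernels (e.g. apex) when sharding
    fused_layernorm: bool = False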


@@ -8,7 +8,7 @@ def build_model(world_size, model_fn):
     org_model = model_fn().cuda()
     # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
+    shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
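The "refactor policy to support fused_layernorm" bullet points at a module substitution done by the policy. Below is a minimal sketch of that technique, not the actual ColossalAI policy code: swap an nn.LayerNorm for apex's FusedLayerNorm while reusing the trained parameters (the helper name maybe_fuse_layernorm is hypothetical):

import torch.nn as nn

try:
    # apex ships a fused LayerNorm kernel; degrade gracefully if absent
    from apex.normalization import FusedLayerNorm
except ImportError:
    FusedLayerNorm = None

def maybe_fuse_layernorm(module: nn.LayerNorm, fused: bool) -> nn.Module:
    """Return a FusedLayerNorm carrying the original parameters, or the module unchanged."""
    if not fused or FusedLayerNorm is None:
        return module
    fused_ln = FusedLayerNorm(module.normalized_shape,
                              eps=module.eps,
                              elementwise_affine=module.elementwise_affine)
    if module.elementwise_affine:
        # reuse the existing weight and bias so trained values survive the swap
        fused_ln.weight = module.weight
        fused_ln.bias = module.bias
    return fused_ln

Reusing the original weight and bias parameters keeps checkpoint loading consistent after the swap, which is presumably what the "layernorm test with load state dict" added in this commit exercises.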