[shardformer] supported fused normalization (#4112)

Frank Lee
2023-06-30 09:32:37 +08:00
parent b1c2901530
commit f3b6aaa6b7
12 changed files with 207 additions and 31 deletions

@@ -8,11 +8,11 @@ def build_model(world_size, model_fn):
     org_model = model_fn().cuda()
 
     # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
+    shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(model_copy)
+    sharded_model = shard_former.shard_model(model_copy).cuda()
 
     return org_model, sharded_model
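
For context, this hunk renames the ShardConfig flag fused_layernorm to enable_fused_normalization and moves the sharded copy onto the GPU after sharding. A minimal sketch of calling the updated helper, assuming a single-process setup and a hypothetical tiny BERT as model_fn; only the flag name and the helper's signature come from this diff, everything else is an assumption.

from transformers import BertConfig, BertModel

def model_fn():
    # Hypothetical tiny model, purely for illustration.
    return BertModel(BertConfig(hidden_size=128, num_hidden_layers=2,
                                num_attention_heads=4, intermediate_size=256))

# build_model is the helper patched above. With enable_fused_normalization=True
# the sharded copy is expected to swap plain nn.LayerNorm for fused kernels
# (an assumption based on the flag name, not verified against the source).
org_model, sharded_model = build_model(world_size=1, model_fn=model_fn)
assert next(sharded_model.parameters()).is_cuda  # .cuda() is now applied after sharding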
@@ -33,4 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
     shard_output = sharded_model(**data)
     shard_output = output_transform_fn(shard_output)
     shard_loss = loss_fn(shard_output)
-    return org_output, org_loss, shard_output, shard_loss
+    return org_output, org_loss, shard_output, shard_loss
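
Taken together, the two helpers give a direct parity check between the original and the sharded model. A sketch, assuming data_gen_fn, output_transform_fn, and loss_fn follow the signatures implied above; the concrete callables and the tolerance are hypothetical.

import torch

data_gen_fn = lambda: {'input_ids': torch.randint(0, 1000, (2, 16)).cuda()}
output_transform_fn = lambda output: output          # identity for this sketch
loss_fn = lambda output: output.last_hidden_state.mean()

org_output, org_loss, shard_output, shard_loss = run_forward(
    org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

# Fused normalization changes the numerics slightly, so compare with a
# tolerance rather than exact equality.
torch.testing.assert_close(org_loss, shard_loss, rtol=1e-3, atol=1e-3)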