[shardformer] update shardformer readme (#4689)

Author: flybird11111 (committed by GitHub)
Date: 2023-09-12 15:14:24 +08:00
Parent: 1d454733c4
Commit: 8844691f4b

4 changed files with 90 additions and 72 deletions


@@ -49,9 +49,12 @@ def train(args):
     # if multiple GPUs, shard the model
     if dist.get_world_size() > 1:
-        shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
+        tp_group = dist.new_group(backend='nccl')
+        shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+                                   enable_tensor_parallelism=True,
+                                   enable_all_optimization=True)
         shard_former = ShardFormer(shard_config=shard_config)
-        model = shard_former.optimize(model)
+        model, _ = shard_former.optimize(model)
     optim = Adam(model.parameters(), lr=args.lr)
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
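For context, here is a minimal, self-contained sketch of the usage this hunk lands in the README. It assumes `torch.distributed` has already been initialized (e.g. via `dist.init_process_group`) and that `ShardConfig` and `ShardFormer` are importable from `colossalai.shardformer`; the helper name `build_sharded_model` is hypothetical, not part of this commit.

```python
import torch.distributed as dist
from torch.optim import Adam

from colossalai.shardformer import ShardConfig, ShardFormer


def build_sharded_model(model, lr: float = 1e-4):
    # Hypothetical helper illustrating the updated API shown in the diff.
    if dist.get_world_size() > 1:
        # Tensor parallelism now takes an explicit NCCL process group.
        tp_group = dist.new_group(backend='nccl')
        shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                                   enable_tensor_parallelism=True,
                                   enable_all_optimization=True)
        shard_former = ShardFormer(shard_config=shard_config)
        # Per the diff, optimize() now returns a (model, shared_params)
        # tuple; the second value is unused here.
        model, _ = shard_former.optimize(model)
    optim = Adam(model.parameters(), lr=lr)
    return model, optim
```

On a single GPU the helper leaves the model untouched, matching the `if dist.get_world_size() > 1` guard in the README example.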