Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 20:54:35 +00:00)
[shardformer] update shardformer readme (#4689)
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
@@ -49,9 +49,12 @@ def train(args):
     # if multiple GPUs, shard the model
     if dist.get_world_size() > 1:
-        shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
+        tp_group = dist.new_group(backend='nccl')
+        shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+                                   enable_tensor_parallelism=True,
+                                   enable_all_optimization=True)
         shard_former = ShardFormer(shard_config=shard_config)
-        model = shard_former.optimize(model)
+        model, _ = shard_former.optimize(model)
 
     optim = Adam(model.parameters(), lr=args.lr)
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
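For context, the following is a minimal sketch of how the updated snippet fits together as a standalone helper. Only the body of the `if` block is taken from the diff above; the import paths, the helper name `shard_if_distributed`, and the surrounding setup are assumptions, since they are not visible in this hunk.

# Minimal sketch; only the sharding block mirrors the diff above.
# Assumed imports (not shown in this hunk).
import torch.distributed as dist
from colossalai.shardformer import ShardConfig, ShardFormer


def shard_if_distributed(model):
    # if multiple GPUs, shard the model
    if dist.get_world_size() > 1:
        # Create a dedicated NCCL process group used for tensor parallelism.
        tp_group = dist.new_group(backend='nccl')
        shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                                   enable_tensor_parallelism=True,
                                   enable_all_optimization=True)
        shard_former = ShardFormer(shard_config=shard_config)
        # optimize() returns the sharded model plus shared parameters;
        # the second value is discarded here, as in the README snippet.
        model, _ = shard_former.optimize(model)
    return model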