[zero] refactor ShardedOptimV2 init method (#416)

This commit is contained in:
Jiarui Fang
2022-03-15 10:45:55 +08:00
committed by GitHub
parent e79ea44247
commit 23ba3fc450
3 changed files with 32 additions and 16 deletions

View File

@@ -52,7 +52,6 @@ def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
optim = optimizer_class(model.parameters(), lr=lr)
sharded_optim = ShardedOptimizerV2(zero_model,
optimizer_class,
shard_strategy,
cpu_offload=cpu_offload,
initial_scale=2**5,
lr=lr)

View File

@@ -59,12 +59,7 @@ def run_dist(rank, world_size, port, shard_strategy):
if dist.get_world_size() > 1:
model = DDP(model)
optim = Adam(model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(zero_model,
CPUAdam,
shard_strategy,
initial_scale=2**5,
cpu_offload=True,
lr=1e-3)
sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, initial_scale=2**5, cpu_offload=True, lr=1e-3)
for i, (data, label) in enumerate(train_dataloader):
if i > 2:
break