mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[zero] refactory ShardedOptimV2 init method (#416)
This commit is contained in:
@@ -52,7 +52,6 @@ def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
|
||||
optim = optimizer_class(model.parameters(), lr=lr)
|
||||
sharded_optim = ShardedOptimizerV2(zero_model,
|
||||
optimizer_class,
|
||||
shard_strategy,
|
||||
cpu_offload=cpu_offload,
|
||||
initial_scale=2**5,
|
||||
lr=lr)
|
||||
|
@@ -59,12 +59,7 @@ def run_dist(rank, world_size, port, shard_strategy):
|
||||
if dist.get_world_size() > 1:
|
||||
model = DDP(model)
|
||||
optim = Adam(model.parameters(), lr=1e-3)
|
||||
sharded_optim = ShardedOptimizerV2(zero_model,
|
||||
CPUAdam,
|
||||
shard_strategy,
|
||||
initial_scale=2**5,
|
||||
cpu_offload=True,
|
||||
lr=1e-3)
|
||||
sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, initial_scale=2**5, cpu_offload=True, lr=1e-3)
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
|
Reference in New Issue
Block a user