Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[checkpoint] save sharded optimizer states (#1237)
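For context: under tensor parallelism each rank holds optimizer states only for the parameter shards it owns, so saving sharded optimizer states means every rank persists and later restores its own slice. A minimal plain-PyTorch sketch of that idea follows; the helper names and per-rank file layout are illustrative assumptions, not the ColossalAI checkpoint API exercised by this commit.

import os
import torch

def save_sharded_optimizer_state(optimizer: torch.optim.Optimizer, dirname: str, rank: int) -> None:
    # On this rank, state_dict() only covers this rank's (sharded) parameters.
    os.makedirs(dirname, exist_ok=True)
    torch.save(optimizer.state_dict(), os.path.join(dirname, f"optimizer.rank{rank}.pt"))

def load_sharded_optimizer_state(optimizer: torch.optim.Optimizer, dirname: str, rank: int) -> None:
    state = torch.load(os.path.join(dirname, f"optimizer.rank{rank}.pt"))
    optimizer.load_state_dict(state)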
@@ -126,6 +126,9 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
         model_reload = ColoDDP(model_reload, pg)
         model_ref = ColoDDP(model_ref, pg)
 
+    init_spec_func(model, pg)
+    init_spec_func(model_ref, pg)
+
     criterion = torch.nn.CrossEntropyLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
     optimizer_reload = torch.optim.Adam(model_reload.parameters(),
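A likely motivation for calling init_spec_func before the optimizers are built: Adam creates its exp_avg / exp_avg_sq buffers lazily, with the shape the parameter has at the first step(), so sharding the weights first is what makes the optimizer states sharded too. A small plain-PyTorch illustration of that behaviour (an ordinary tensor stands in for one rank's weight shard; this snippet is not part of the diff):

import torch

# An ordinary parameter stands in for one rank's shard of a weight matrix.
param = torch.nn.Parameter(torch.randn(8, 4))
optim = torch.optim.Adam([param], lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

param.grad = torch.randn_like(param)
optim.step()

# Adam's per-parameter state is created on the first step() and matches the
# parameter's (sharded) shape, so saving it per rank saves a shard.
state = optim.state[param]
assert state["exp_avg"].shape == param.shape
assert state["exp_avg_sq"].shape == param.shape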
@@ -135,23 +138,21 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
                                         weight_decay=0)
     optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
 
     lr_scheduler = None
     if test_scheduler == 'colossalai_cosine_warmup':
         lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
         lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
                                                       total_steps=num_epoch,
                                                       warmup_steps=warmup_epoch)
 
     elif test_scheduler == 'torch_cosine':
         lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
         lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
 
     elif test_scheduler == 'torch_lambda':
         lr_lambda = lambda epoch: 0.95
         lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
         lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
 
-        init_spec_func(model, pg)
-        init_spec_func(model_ref, pg)
     else:
         raise TypeError(f"{test_scheduler} is invalid")
 
     for epoch in range(0, num_epoch):
         if epoch <= test_epoch:
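Each of the three scheduler variants exposes state_dict()/load_state_dict(), which is what lets a checkpoint save and restore the learning-rate schedule alongside the optimizer states. A minimal stand-alone sketch of that round trip with the torch_lambda case (not part of the diff; values chosen for illustration):

import torch
from torch.optim.lr_scheduler import MultiplicativeLR

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=0.001)
sched = MultiplicativeLR(optimizer=opt, lr_lambda=lambda epoch: 0.95)

for _ in range(3):
    opt.step()
    sched.step()    # lr is now 0.001 * 0.95 ** 3

saved = sched.state_dict()    # this is what a checkpoint would persist

# A freshly built scheduler can pick up the saved schedule position.
opt_reload = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.001)
sched_reload = MultiplicativeLR(optimizer=opt_reload, lr_lambda=lambda epoch: 0.95)
sched_reload.load_state_dict(saved)
assert sched_reload.last_epoch == sched.last_epoch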
@@ -212,7 +213,11 @@ def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, test_scheduler, pg)
+    run_checkpoint(init_1d_row_for_linear_weight_spec,
+                   use_ddp,
+                   test_epoch=test_epoch,
+                   test_scheduler=test_scheduler,
+                   pg=pg)
 
 
 @pytest.mark.skip
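For reference, an entry point with the signature run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler) is typically driven by torch.multiprocessing.spawn, which supplies the rank as the first positional argument. A sketch under that assumption (the test's real harness and its port selection are not shown in this diff; the port below is a placeholder):

import torch.multiprocessing as mp

def launch_locally(world_size: int = 4) -> None:
    # Assumes this runs in the same module that defines run_dist.
    port = 29500    # placeholder; any free TCP port works
    mp.spawn(run_dist,
             args=(world_size, port, True, 1, "colossalai_cosine_warmup"),
             nprocs=world_size)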
@@ -236,4 +241,4 @@ def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
 
 
 if __name__ == '__main__':
-    test_checkpoint(4, True, 1, 1)
+    test_checkpoint(4, True, 1, "colossalai_cosine_warmup")