[checkpoint] save sharded optimizer states (#1237)

This commit is contained in:
Jiarui Fang
2022-07-08 16:33:13 +08:00
committed by GitHub
parent 4a76084dc9
commit 20da6e48c8
3 changed files with 28 additions and 19 deletions

View File

@@ -126,6 +126,9 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
model_reload = ColoDDP(model_reload, pg)
model_ref = ColoDDP(model_ref, pg)
init_spec_func(model, pg)
init_spec_func(model_ref, pg)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
optimizer_reload = torch.optim.Adam(model_reload.parameters(),
@@ -135,23 +138,21 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
weight_decay=0)
optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
lr_scheduler = None
if test_scheduler == 'colossalai_cosine_warmup':
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
total_steps=num_epoch,
warmup_steps=warmup_epoch)
elif test_scheduler == 'torch_cosine':
lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
elif test_scheduler == 'torch_lambda':
lr_lambda = lambda epoch: 0.95
lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
init_spec_func(model, pg)
init_spec_func(model_ref, pg)
else:
raise TypeError(f"{test_scheduler} is invalid")
for epoch in range(0, num_epoch):
if epoch <= test_epoch:
@@ -212,7 +213,11 @@ def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, test_scheduler, pg)
run_checkpoint(init_1d_row_for_linear_weight_spec,
use_ddp,
test_epoch=test_epoch,
test_scheduler=test_scheduler,
pg=pg)
@pytest.mark.skip
@@ -236,4 +241,4 @@ def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
if __name__ == '__main__':
test_checkpoint(4, True, 1, 1)
test_checkpoint(4, True, 1, "colossalai_cosine_warmup")