[checkpoint] add ColoOptimizer checkpointing (#1316)

This commit is contained in:
Jiarui Fang
2022-07-15 09:52:55 +08:00
committed by GitHub
parent 7c2634f4b3
commit 9e4c6449b0
3 changed files with 74 additions and 15 deletions

View File

@@ -77,6 +77,18 @@ def remove(path):
raise ValueError("file {} is not a file or dir.".format(path))
def compare_optims(optim1, optim2):
state1 = optim1.state_dict()['state']
state2 = optim2.state_dict()['state']
for k, p1 in state1.items():
if k not in state2:
continue
p2 = state2[k]
if isinstance(p1, ColoTensor):
assert isinstance(p2, ColoTensor)
assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
model_reload = model_reload.cuda()
model_reload.train()
colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1))
opt_class = torch.optim.Adam
colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
run_reload = False
for i, (data, label) in enumerate(train_dataloader):
@@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
# Bcast rank0 data to all processes
if criterion:
output = model(data)
output_reload = model_reload(data)
loss = criterion(output, label)
loss_reload = criterion(output_reload, label)
else:
output = model(data, label)
loss = output
loss = model(data, label)
loss_reload = model_reload(data, label)
loss.backward()
colo_optimizer.step()
loss_reload.backward()
if run_reload:
colo_optimizer_reload.zero_grad()
if criterion:
output_reload = model_reload(data)
loss_reload = criterion(output_reload, label)
else:
loss_reload = model_reload(data, label)
loss_reload.backward()
colo_optimizer_reload.step()
if i > 2:
break
if not os.path.isdir('./checkpoint') and rank == 0:
os.mkdir('./checkpoint')
save_checkpoint('./checkpoint', 0, model, None, None)
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, None, None)
# Since model is sharded, we merge them before param checking.
for p in model.parameters():
@@ -155,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
p.to_replicate_()
check_param_equal(model, model_reload)
compare_optims(colo_optimizer, colo_optimizer_reload)
if rank == 0:
remove('./checkpoint')
@@ -163,7 +191,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
for model_name in ['bert', 'simple_net']:
for model_name in ['simple_net', 'bert']:
_run_checkpoint(model_name,
init_1d_row_for_linear_weight_spec,
use_ddp,