[checkpoint] add ColoOptimizer checkpointing (#1316)
@@ -77,6 +77,18 @@ def remove(path):
         raise ValueError("file {} is not a file or dir.".format(path))


+def compare_optims(optim1, optim2):
+    state1 = optim1.state_dict()['state']
+    state2 = optim2.state_dict()['state']
+    for k, p1 in state1.items():
+        if k not in state2:
+            continue
+        p2 = state2[k]
+        if isinstance(p1, ColoTensor):
+            assert isinstance(p2, ColoTensor)
+            assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
+
+
 def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
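For context on what compare_optims walks over: optimizer.state_dict()['state'] maps a parameter id to that parameter's optimizer state. A minimal single-process sketch with a plain PyTorch optimizer (toy model and names are illustrative, not from this test; with ColoTensor state the test additionally calls to_replicate_() to gather shards before comparing):

# Plain-torch sketch of comparing two optimizers' states entry by entry.
import torch

model = torch.nn.Linear(4, 4)
optim1 = torch.optim.Adam(model.parameters(), lr=0.1)
model(torch.randn(2, 4)).sum().backward()
optim1.step()

# Round-trip through a state dict, as a checkpoint load would do.
optim2 = torch.optim.Adam(model.parameters(), lr=0.1)
optim2.load_state_dict(optim1.state_dict())

state1 = optim1.state_dict()['state']   # e.g. {0: {'step': ..., 'exp_avg': tensor, ...}, 1: {...}}
state2 = optim2.state_dict()['state']
for k, s1 in state1.items():
    for name, v1 in s1.items():
        v2 = state2[k][name]
        if isinstance(v1, torch.Tensor):
            assert torch.allclose(v1, v2, rtol=1e-3, atol=1e-1)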
@@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
     model_reload = model_reload.cuda()
     model_reload.train()

-    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1))
+    opt_class = torch.optim.Adam
+    colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
+    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
+    run_reload = False

     for i, (data, label) in enumerate(train_dataloader):

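A side note on the removed line: torch.optim.SGD expects an iterable of parameter tensors plus an lr keyword, so model.named_parameters() (which yields (name, tensor) pairs) and r=0.1 would both raise a TypeError. A corrected plain-SGD construction, for reference (toy model is illustrative):

import torch

m = torch.nn.Linear(2, 2)                      # illustrative stand-in model
opt = torch.optim.SGD(m.parameters(), lr=0.1)  # parameters(), not named_parameters(); lr=, not r=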
@@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
             output_reload = model_reload(data)
             loss = criterion(output, label)
             loss_reload = criterion(output_reload, label)
         else:
-            output = model(data, label)
-            loss = output
+            loss = model(data, label)
+            loss_reload = model_reload(data, label)

         loss.backward()
         colo_optimizer.step()
-        loss_reload.backward()
+
+        if run_reload:
+            colo_optimizer_reload.zero_grad()
+            if criterion:
+                output_reload = model_reload(data)
+                loss_reload = criterion(output_reload, label)
+            else:
+                loss_reload = model_reload(data, label)
+            loss_reload.backward()
+            colo_optimizer_reload.step()

         if i > 2:
             break

     if not os.path.isdir('./checkpoint') and rank == 0:
         os.mkdir('./checkpoint')
-    save_checkpoint('./checkpoint', 0, model, None, None)
+    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
     dist.barrier()
-    load_checkpoint('./checkpoint', 0, model_reload, None, None)
+    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
     dist.barrier()

     # Since model is sharded, we merge them before param checking.
     for p in model.parameters():
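The save/load round trip above is the heart of the change: the optimizer is now passed through the checkpoint instead of None. A single-process sketch of the same pattern, with plain torch.save/torch.load standing in for colossalai's save_checkpoint and load_checkpoint (which additionally synchronize ranks and handle sharded ColoTensors; file name and toy model are illustrative):

import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters(), lr=0.1)
model(torch.randn(2, 4)).sum().backward()
optim.step()

# Save model and optimizer state together.
torch.save({'model': model.state_dict(), 'optim': optim.state_dict()}, 'ckpt.pt')

# Reload both into fresh objects.
model_reload = torch.nn.Linear(4, 4)
optim_reload = torch.optim.Adam(model_reload.parameters(), lr=0.1)
ckpt = torch.load('ckpt.pt')
model_reload.load_state_dict(ckpt['model'])
optim_reload.load_state_dict(ckpt['optim'])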
@@ -155,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         p.to_replicate_()

     check_param_equal(model, model_reload)
-
+    compare_optims(colo_optimizer, colo_optimizer_reload)
     if rank == 0:
         remove('./checkpoint')

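Why the test merges before checking: each rank holds only a shard of a distributed parameter, so the shards must be reassembled into the full tensor before an element-wise comparison. A toy sketch of 1D row sharding across two "ranks" (in the real test, to_replicate_() performs this gather with collectives):

import torch

full = torch.randn(4, 4)
shards = list(full.chunk(2, dim=0))   # what rank 0 and rank 1 would each hold
merged = torch.cat(shards, dim=0)     # the replicate / merge step
assert torch.equal(merged, full)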
@@ -163,7 +191,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    for model_name in ['bert', 'simple_net']:
+    for model_name in ['simple_net', 'bert']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,
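For context, run_dist is the per-rank entry point; tests like this are typically driven by spawning one process per rank with torch.multiprocessing. A hedged sketch of such a driver (the real test wraps this in pytest with a free-port utility; the port value and argument choices here are assumptions):

import torch.multiprocessing as mp
from functools import partial

def test_checkpoint(world_size=2):
    # mp.spawn passes the rank as the first positional argument.
    run = partial(run_dist,             # run_dist from the test file above
                  world_size=world_size,
                  port=29500,           # assumed free port
                  use_ddp=False,
                  use_mp_reload=False,
                  test_scheduler=None)
    mp.spawn(run, nprocs=world_size)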