[checkpoint] add ColoOptimizer checkpointing (#1316)

This commit is contained in:
Jiarui Fang 2022-07-15 09:52:55 +08:00 committed by GitHub
parent 7c2634f4b3
commit 9e4c6449b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 15 deletions

View File

@@ -1,6 +1,3 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor

View File

@@ -1,12 +1,15 @@
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.tensor import ColoTensor, DistSpecManager from colossalai.tensor import ColoTensor, DistSpecManager
from colossalai.nn.optimizer import ColossalaiOptimizer
from copy import copy
from typing import Optional
def save_checkpoint(dire: str, def save_checkpoint(dire: str,
epoch: int, epoch: int,
model: torch.nn.Module, model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None, optimizer: Optional[ColossalaiOptimizer] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args, *args,
**kwargs): **kwargs):
@@ -16,7 +19,7 @@ def save_checkpoint(dire: str,
dire (str): directory to save the checkpoint files. dire (str): directory to save the checkpoint files.
epoch (int): the number of epoch epoch (int): the number of epoch
model (torch.nn.Module): a torch module initialized by ColoInitContext model (torch.nn.Module): a torch module initialized by ColoInitContext
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None. optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
""" """
@@ -41,11 +44,21 @@ def save_checkpoint(dire: str,
# delete the new dict # delete the new dict
del new_dict del new_dict
optim_state_copy = copy(optimizer.state_dict())
for k, v in optim_state_copy['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
t.to_replicate_()
if dist.get_rank() == 0:
model_state = {'epoch': epoch, 'optim': optim_state_copy}
torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
del optim_state_copy
def load_checkpoint(dire, def load_checkpoint(dire,
epoch: int, epoch: int,
model: torch.nn.Module, model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None, optimizer: Optional[ColossalaiOptimizer] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args, *args,
**kwargs): **kwargs):
@@ -56,7 +69,7 @@ def load_checkpoint(dire,
epoch (int): _description_ epoch (int): _description_
rank (int): _description_ rank (int): _description_
model (torch.nn.Module): _description_ model (torch.nn.Module): _description_
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None. optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None. lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
""" """
@@ -74,3 +87,24 @@ def load_checkpoint(dire,
for k, v in model.state_dict().items(): for k, v in model.state_dict().items():
if isinstance(v, ColoTensor): if isinstance(v, ColoTensor):
v.set_tensor_spec(*mapping[k]) v.set_tensor_spec(*mapping[k])
del mapping
mapping = dict()
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = (t.dist_spec, t.compute_spec)
t.to_replicate_()
colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
optimizer.load_state_dict(colo_checkpoint['optim'])
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
# skip key not in mapping.
# For Adam, if it does not execute step() once, there will be no exp_avg and exp_avg_sq in optimizer
if (k, n) not in mapping:
continue
t.set_tensor_spec(*mapping[(k, n)])

View File

@@ -77,6 +77,18 @@ def remove(path):
raise ValueError("file {} is not a file or dir.".format(path)) raise ValueError("file {} is not a file or dir.".format(path))
def compare_optims(optim1, optim2):
state1 = optim1.state_dict()['state']
state2 = optim2.state_dict()['state']
for k, p1 in state1.items():
if k not in state2:
continue
p2 = state2[k]
if isinstance(p1, ColoTensor):
assert isinstance(p2, ColoTensor)
assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
model_reload = model_reload.cuda() model_reload = model_reload.cuda()
model_reload.train() model_reload.train()
colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1)) opt_class = torch.optim.Adam
colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
run_reload = False
for i, (data, label) in enumerate(train_dataloader): for i, (data, label) in enumerate(train_dataloader):
@@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
# Bcast rank0 data to all processes # Bcast rank0 data to all processes
if criterion: if criterion:
output = model(data) output = model(data)
output_reload = model_reload(data)
loss = criterion(output, label) loss = criterion(output, label)
loss_reload = criterion(output_reload, label)
else: else:
output = model(data, label) loss = model(data, label)
loss = output loss_reload = model_reload(data, label)
loss.backward() loss.backward()
colo_optimizer.step() loss_reload.backward()
if run_reload:
colo_optimizer_reload.zero_grad()
if criterion:
output_reload = model_reload(data)
loss_reload = criterion(output_reload, label)
else:
loss_reload = model_reload(data, label)
loss_reload.backward()
colo_optimizer_reload.step()
if i > 2: if i > 2:
break break
if not os.path.isdir('./checkpoint') and rank == 0: if not os.path.isdir('./checkpoint') and rank == 0:
os.mkdir('./checkpoint') os.mkdir('./checkpoint')
save_checkpoint('./checkpoint', 0, model, None, None) save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
dist.barrier() dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, None, None)
# Since model is sharded, we merge them before param checking. # Since model is sharded, we merge them before param checking.
for p in model.parameters(): for p in model.parameters():
@@ -155,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
p.to_replicate_() p.to_replicate_()
check_param_equal(model, model_reload) check_param_equal(model, model_reload)
compare_optims(colo_optimizer, colo_optimizer_reload)
if rank == 0: if rank == 0:
remove('./checkpoint') remove('./checkpoint')
@@ -163,7 +191,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler): def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size) pg = ProcessGroup(tp_degree=world_size)
for model_name in ['bert', 'simple_net']: for model_name in ['simple_net', 'bert']:
_run_checkpoint(model_name, _run_checkpoint(model_name,
init_1d_row_for_linear_weight_spec, init_1d_row_for_linear_weight_spec,
use_ddp, use_ddp,