From 9e4c6449b0ae314a9e2b6668895a64f52be48abb Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 15 Jul 2022 09:52:55 +0800
Subject: [PATCH] [checkpoint] add ColoOptimizer checkpointing (#1316)

---
 .../nn/optimizer/colossalai_optimizer.py      |  3 --
 .../utils/checkpoint/module_checkpoint.py     | 42 ++++++++++++++++--
 tests/test_utils/test_colo_checkpoint.py      | 44 +++++++++++++++----
 3 files changed, 74 insertions(+), 15 deletions(-)

diff --git a/colossalai/nn/optimizer/colossalai_optimizer.py b/colossalai/nn/optimizer/colossalai_optimizer.py
index fb0c43903..34f5a9541 100644
--- a/colossalai/nn/optimizer/colossalai_optimizer.py
+++ b/colossalai/nn/optimizer/colossalai_optimizer.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 import torch
 import torch.nn as nn
 from torch import Tensor
diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py
index 119d719b2..81370ad0f 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/utils/checkpoint/module_checkpoint.py
@@ -1,12 +1,15 @@
 import torch
 import torch.distributed as dist
 from colossalai.tensor import ColoTensor, DistSpecManager
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from copy import copy
+from typing import Optional
 
 
 def save_checkpoint(dire: str,
                     epoch: int,
                     model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
+                    optimizer: Optional[ColossalaiOptimizer] = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     *args,
                     **kwargs):
@@ -16,7 +19,7 @@ def save_checkpoint(dire: str,
         dire (str): directory to save the checkpoint files.
         epoch (int): the number of epoch
         model (torch.nn.Module): a torch module initialized by ColoInitContext
-        optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
+        optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
 
@@ -41,11 +44,21 @@ def save_checkpoint(dire: str,
     # delete the new dict
     del new_dict
 
+    optim_state_copy = copy(optimizer.state_dict())
+    for k, v in optim_state_copy['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                t.to_replicate_()
+    if dist.get_rank() == 0:
+        model_state = {'epoch': epoch, 'optim': optim_state_copy}
+        torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+    del optim_state_copy
+
 
 def load_checkpoint(dire,
                     epoch: int,
                     model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
+                    optimizer: Optional[ColossalaiOptimizer] = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     *args,
                     **kwargs):
@@ -56,7 +69,7 @@ def load_checkpoint(dire,
         epoch (int): _description_
         rank (int): _description_
         model (torch.nn.Module): _description_
-        optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
+        optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
""" @@ -74,3 +87,24 @@ def load_checkpoint(dire, for k, v in model.state_dict().items(): if isinstance(v, ColoTensor): v.set_tensor_spec(*mapping[k]) + + del mapping + mapping = dict() + + for k, v in optimizer.state_dict()['state'].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + mapping[(k, n)] = (t.dist_spec, t.compute_spec) + t.to_replicate_() + + colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch)) + optimizer.load_state_dict(colo_checkpoint['optim']) + + for k, v in optimizer.state_dict()['state'].items(): + for n, t in v.items(): + if isinstance(t, ColoTensor): + # skip key not in mapping. + # For Adam, if it dose not execute step() once, there will be not exp_avg and exp_avg_sq in optimizer + if (k, n) not in mapping: + continue + t.set_tensor_spec(*mapping[(k, n)]) diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py index edc463b0d..524a39be1 100644 --- a/tests/test_utils/test_colo_checkpoint.py +++ b/tests/test_utils/test_colo_checkpoint.py @@ -77,6 +77,18 @@ def remove(path): raise ValueError("file {} is not a file or dir.".format(path)) +def compare_optims(optim1, optim2): + state1 = optim1.state_dict()['state'] + state2 = optim2.state_dict()['state'] + for k, p1 in state1.items(): + if k not in state2: + continue + p2 = state2[k] + if isinstance(p1, ColoTensor): + assert isinstance(p2, ColoTensor) + assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1) + + def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch model_reload = model_reload.cuda() model_reload.train() - colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1)) + opt_class = torch.optim.Adam + colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1)) + colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1)) + run_reload = False for i, (data, label) in enumerate(train_dataloader): @@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch # Bcast rank0 data to all processes if criterion: output = model(data) + output_reload = model_reload(data) loss = criterion(output, label) + loss_reload = criterion(output_reload, label) else: - output = model(data, label) - loss = output + loss = model(data, label) + loss_reload = model_reload(data, label) loss.backward() - colo_optimizer.step() + loss_reload.backward() + + if run_reload: + colo_optimizer_reload.zero_grad() + if criterion: + output_reload = model_reload(data) + loss_reload = criterion(output_reload, label) + else: + loss_reload = model_reload(data, label) + loss_reload.backward() + colo_optimizer_reload.step() if i > 2: break if not os.path.isdir('./checkpoint') and rank == 0: os.mkdir('./checkpoint') - save_checkpoint('./checkpoint', 0, model, None, None) + save_checkpoint('./checkpoint', 0, model, colo_optimizer, None) + dist.barrier() + load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None) dist.barrier() - load_checkpoint('./checkpoint', 0, model_reload, None, None) # Since model is sharded, we merge them before param checking. 
     for p in model.parameters():
@@ -155,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         p.to_replicate_()
 
     check_param_equal(model, model_reload)
-
+    compare_optims(colo_optimizer, colo_optimizer_reload)
     if rank == 0:
         remove('./checkpoint')
 
@@ -163,7 +191,7 @@ def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)
-    for model_name in ['bert', 'simple_net']:
+    for model_name in ['simple_net', 'bert']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,
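
Usage sketch (editorial note, not part of the patch): the snippet below shows how the optimizer-aware checkpointing introduced here might be driven from a training script. It is a minimal sketch under stated assumptions: the import path is inferred from the module location in the diff, the checkpoint directory, epoch value, and learning rate are placeholders, and a distributed environment is assumed to be already initialized (e.g. via colossalai.launch).

```python
# Illustrative sketch only, not part of this patch.
# Assumes `model` and `model_reload` were built under ColoInitContext, so their
# parameters are ColoTensors, and that colossalai.launch has already set up the
# default process group (needed for dist.barrier()).
import torch
import torch.distributed as dist
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.checkpoint.module_checkpoint import save_checkpoint, load_checkpoint


def checkpoint_roundtrip(model, model_reload, dire='./checkpoint', epoch=0):
    optim = ColossalaiOptimizer(torch.optim.Adam(model.parameters(), lr=0.1))
    optim_reload = ColossalaiOptimizer(torch.optim.Adam(model_reload.parameters(), lr=0.1))

    # Save model weights plus optimizer state; ColoTensors in the optimizer
    # state are converted to replicate form and written by rank 0 only.
    save_checkpoint(dire, epoch, model, optim)
    dist.barrier()

    # Load into a fresh model/optimizer pair; after load_state_dict() the
    # original dist specs of the optimizer-state ColoTensors are restored.
    load_checkpoint(dire, epoch, model_reload, optim_reload)
    dist.barrier()
```

As in the updated test, the barrier between save and load keeps non-zero ranks from reading the checkpoint file before rank 0 has finished writing it.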