From f92c100ddd4a93a72f710dc476936c6890b8bffe Mon Sep 17 00:00:00 2001
From: HELSON
Date: Tue, 19 Jul 2022 14:15:28 +0800
Subject: [PATCH] [checkpoint] use gather_tensor in checkpoint and update its
 unit test (#1339)

---
 colossalai/tensor/colo_tensor.py              |   2 +-
 colossalai/tensor/process_group.py            |   9 ++
 .../utils/checkpoint/module_checkpoint.py     | 136 ++++++++++--------
 colossalai/utils/checkpoint/utils.py          |  50 +++++++
 tests/test_utils/test_colo_checkpoint.py      |  56 ++++----
 .../test_utils/test_colo_checkpoint_tools.py  |  47 ++++++
 6 files changed, 209 insertions(+), 91 deletions(-)
 create mode 100644 colossalai/utils/checkpoint/utils.py
 create mode 100644 tests/test_utils/test_colo_checkpoint_tools.py

diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index f5f0b2505..51cc4619a 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -262,7 +262,7 @@ class ColoTensor(torch.Tensor):
         replicated_t = self.redistribute(dist_spec=ReplicaSpec())
         return replicated_t.view(*args)
 
-    def size_global(self, args: Optional[int] = None):
+    def size_global(self, args: Optional[int] = None) -> torch.Size:
         """override the torch buildin size()
         the shape passed in must be in a replicate placement.
         Returns:
diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
index 12fba646d..a9c04244a 100644
--- a/colossalai/tensor/process_group.py
+++ b/colossalai/tensor/process_group.py
@@ -141,9 +141,18 @@ class ProcessGroup:
     def rank(self):
         return self._rank
 
+    def ranks_in_group(self):
+        return self._rank_list
+
     def world_size(self):
         return self._world_size
 
+    def tp_rank_list(self):
+        return self._tp_rank_list
+
+    def dp_rank_list(self):
+        return self._dp_rank_list
+
     def tp_local_rank(self):
         return self._rank % self._tp_degree
diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py
index 81370ad0f..8e9654e6f 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/utils/checkpoint/module_checkpoint.py
@@ -1,8 +1,8 @@
 import torch
 import torch.distributed as dist
-from colossalai.tensor import ColoTensor, DistSpecManager
+from colossalai.tensor import ColoTensor
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from copy import copy
+from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
 from typing import Optional
 
 
@@ -22,37 +22,52 @@ def save_checkpoint(dire: str,
         optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
+    rank = dist.get_rank()
+    model_state = model.state_dict()
+    # save the dist context about the tensors in a new dict, while still maintaining the original dict
+    for k, v in model_state.items():
+        if isinstance(v, ColoTensor):
+            gather_tensor(v)    # gather sharded tensors to rank0
+            # don't recover the tensors on rank0, since the dict is only a copy of the model's state
+
+    if rank == 0:
+        # sanity check
+        for k, v in model_state.items():
+            if isinstance(v, ColoTensor):
+                assert v.save_ready
+                assert v.is_replicate()
+                delattr(v, 'save_ready')
+        # model saving
+        save_state = {'epoch': epoch, 'model': model_state}
+        torch.save(save_state, dire + '/epoch_{}_model.pth'.format(epoch))
+
+    # delete the temporary state dict
+    del model_state
+    # synchronize all the processes
+    dist.barrier()
 
     mapping = dict()
-    new_dict = dict()
-
-    # save the dist context about the tensors in a new dict, while still maintain the original dict.
-    for k, v in model.state_dict().items():
-        if isinstance(v, ColoTensor):
-            mapping[k] = (v.dist_spec, v.compute_spec)
-            new_dict[k] = v.to_replicate().detach()
-        else:
-            new_dict[k] = v
-    if dist.get_rank() == 0:
-        for k, v in new_dict.items():
-            if isinstance(v, ColoTensor):
-                assert v.is_replicate()
-
-    model_state = {'epoch': epoch, 'model': new_dict}
-    torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
-
-    # delete the new dict
-    del new_dict
-
-    optim_state_copy = copy(optimizer.state_dict())
-    for k, v in optim_state_copy['state'].items():
+    optim_state = optimizer.state_dict()
+    for k, v in optim_state['state'].items():
         for n, t in v.items():
             if isinstance(t, ColoTensor):
-                t.to_replicate_()
-    if dist.get_rank() == 0:
-        model_state = {'epoch': epoch, 'optim': optim_state_copy}
-        torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
-    del optim_state_copy
+                mapping[(k, n)] = t.dist_spec
+                gather_tensor(t)
+
+    if rank == 0:
+        save_state = {'epoch': epoch, 'optim': optim_state}
+        torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+        # restore the dist specs of the colo tensors on rank0
+        for k, v in optimizer.state_dict()['state'].items():
+            for n, t in v.items():
+                if isinstance(t, ColoTensor):
+                    assert hasattr(t, 'save_ready')
+                    t.set_dist_spec(mapping[(k, n)])
+                    delattr(t, 'save_ready')
+
+    del optim_state
+    del mapping
+    dist.barrier()
 
 
 def load_checkpoint(dire,
@@ -72,39 +87,42 @@ def load_checkpoint(dire,
         optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
     """
+    rank = dist.get_rank()
+    mapping = dict()
+    for n, p in model.named_parameters():
+        if isinstance(p, ColoTensor):
+            mapping[n] = p.dist_spec
+            gather_tensor(p)
+
+    if rank == 0:
+        load_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
+        model.load_state_dict(load_state['model'])
+    dist.barrier()
+
+    # scatter loaded parameters
+    for n, p in model.named_parameters():
+        if isinstance(p, ColoTensor):
+            scatter_tensor(p, mapping[n])
+            if rank == 0:
+                assert hasattr(p, 'save_ready')
+                delattr(p, 'save_ready')
+    del mapping
 
     mapping = dict()
-    for k, v in model.state_dict().items():
-        if isinstance(v, ColoTensor):
-            mapping[k] = (v.dist_spec, v.compute_spec)
-            v.to_replicate_()
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                mapping[(k, n)] = t.dist_spec
+                gather_tensor(t)
 
-    model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
-    model.load_state_dict(model_state['model'])
+    if rank == 0:
+        colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+        optimizer.load_state_dict(colo_checkpoint['optim'])
+    dist.barrier()
 
-    # reset tensors to original dist spec.
-    with DistSpecManager.no_grad():
-        for k, v in model.state_dict().items():
-            if isinstance(v, ColoTensor):
-                v.set_tensor_spec(*mapping[k])
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                scatter_tensor(t, mapping[(k, n)])
 
     del mapping
-    mapping = dict()
-
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                mapping[(k, n)] = (t.dist_spec, t.compute_spec)
-                t.to_replicate_()
-
-    colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
-    optimizer.load_state_dict(colo_checkpoint['optim'])
-
-    for k, v in optimizer.state_dict()['state'].items():
-        for n, t in v.items():
-            if isinstance(t, ColoTensor):
-                # skip key not in mapping.
-                # For Adam, if it dose not execute step() once, there will be not exp_avg and exp_avg_sq in optimizer
-                if (k, n) not in mapping:
-                    continue
-                t.set_tensor_spec(*mapping[(k, n)])
diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/utils/checkpoint/utils.py
new file mode 100644
index 000000000..e018d3711
--- /dev/null
+++ b/colossalai/utils/checkpoint/utils.py
@@ -0,0 +1,50 @@
+import torch
+import torch.distributed as dist
+from colossalai.tensor import ColoTensor, ColoTensorSpec
+from colossalai.tensor.distspec import _DistSpec
+
+
+def gather_tensor(colo_tensor: ColoTensor) -> None:
+    """Make colo_tensor replicated on rank 0, while other ranks keep their original dist spec.
+    """
+    if not colo_tensor.is_replicate():
+        pg = colo_tensor.get_process_group()
+        # only the TP group that contains global rank 0 needs to gather
+        if pg.tp_rank_list()[0] == 0:
+            old_dist_spec = colo_tensor.dist_spec
+            colo_tensor.to_replicate_()
+            if dist.get_rank() != 0:
+                colo_tensor.set_dist_spec(old_dist_spec)
+
+        # synchronize all processes to avoid unexpected problems
+        dist.barrier()
+
+    if dist.get_rank() == 0:
+        setattr(colo_tensor, 'save_ready', True)    # set the saving signature
+
+
+def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
+    """The reverse operation of `gather_tensor`.
+    """
+    if dist_spec.placement == 'r':
+        dist.broadcast(colo_tensor.data, 0)
+    else:
+        global_size = colo_tensor.size_global()
+
+        if dist.get_rank() == 0:
+            entire_data = colo_tensor.data
+        else:
+            entire_data = torch.empty(global_size, device=colo_tensor.device)
+        dist.broadcast(entire_data, 0)
+
+        if dist.get_rank() == 0:
+            colo_tensor.set_dist_spec(dist_spec)
+        else:
+            rep_tensor = ColoTensor(entire_data, ColoTensorSpec(
+                pg=colo_tensor.get_process_group(),
+                compute_attr=colo_tensor.compute_spec))
+            rep_tensor.set_dist_spec(dist_spec)
+            with torch.no_grad():
+                colo_tensor.data.copy_(rep_tensor.data)
+    # synchronize all processes to avoid unexpected problems
+    dist.barrier()
diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py
index 524a39be1..d25b17f10 100644
--- a/tests/test_utils/test_colo_checkpoint.py
+++ b/tests/test_utils/test_colo_checkpoint.py
@@ -1,6 +1,7 @@
 import os, shutil
 import torch
 import pytest
+from copy import deepcopy
 from functools import partial
 
 import torch.multiprocessing as mp
@@ -15,8 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
-from colossalai.nn.parallel.data_parallel import ColoDDP
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.optimizer import ColossalaiOptimizer
 
@@ -63,8 +63,8 @@ def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
 
 
 def check_param_equal(model, torch_model):
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
-        assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)
+    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
+        assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)
 
 
 def remove(path):
@@ -84,9 +84,13 @@ def compare_optims(optim1, optim2):
         if k not in state2:
             continue
         p2 = state2[k]
-        if isinstance(p1, ColoTensor):
-            assert isinstance(p2, ColoTensor)
-            assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
+        for n, t1 in p1.items():
+            if n not in p2:
+                continue
+            t2 = p2[n]
+            if isinstance(t1, ColoTensor):
+                assert isinstance(t2, ColoTensor)
+                assert torch.allclose(t1, t2, rtol=0, atol=0)
 
 
 def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
@@ -99,7 +103,6 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
     # set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
-        model_reload = model_builder(checkpoint=True)
 
     if use_mp_reload:
         if 'bert' == model_name:
@@ -119,25 +122,26 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
             elif 'token_type_embeddings' in name and 'weight' in name:
                 init_1d_col_embedding(p, pg)
             elif p.process_group.tp_world_size() == 1:
-                p.redistribute(ReplicaSpec(), pg)
+                p.set_process_group(pg)
     elif "simple_net" == model_name:
         init_spec_func(model, pg)
 
+    model_reload = deepcopy(model)
     model = model.cuda()
-    model.train()
+    model.eval()
 
     model_reload = model_reload.cuda()
-    model_reload.train()
+    model_reload.eval()
 
     opt_class = torch.optim.Adam
     colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
+    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
 
-    run_reload = False
     for i, (data, label) in enumerate(train_dataloader):
 
         # Zero grad
         colo_optimizer.zero_grad()
+        colo_optimizer_reload.zero_grad()
 
         data = data.to(get_current_device())
         label = label.to(get_current_device())
@@ -155,43 +159,33 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         loss.backward()
         loss_reload.backward()
 
-        if run_reload:
-            colo_optimizer_reload.zero_grad()
-            if criterion:
-                output_reload = model_reload(data)
-                loss_reload = criterion(output_reload, label)
-            else:
-                loss_reload = model_reload(data, label)
-            loss_reload.backward()
-            colo_optimizer_reload.step()
+        colo_optimizer.step()
+        colo_optimizer_reload.step()
 
         if i > 2:
             break
 
     if not os.path.isdir('./checkpoint') and rank == 0:
         os.mkdir('./checkpoint')
+    dist.barrier()
+
     save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
-    dist.barrier()
     load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
-    dist.barrier()
-
-    # Since model is sharded, we merge them before param checking.
-    for p in model.parameters():
-        p.to_replicate_()
-
-    for p in model_reload.parameters():
-        p.to_replicate_()
 
     check_param_equal(model, model_reload)
     compare_optims(colo_optimizer, colo_optimizer_reload)
+
     if rank == 0:
         remove('./checkpoint')
+    dist.barrier()
 
 
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    for model_name in ['simple_net', 'bert']:
+    # TODO(haichen) add BERT in the test
+    # the data loader of BERT is in DDP mode, so the input data is not replicated in the TP context
+    for model_name in ['simple_net']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,
diff --git a/tests/test_utils/test_colo_checkpoint_tools.py b/tests/test_utils/test_colo_checkpoint_tools.py
new file mode 100644
index 000000000..551886f25
--- /dev/null
+++ b/tests/test_utils/test_colo_checkpoint_tools.py
@@ -0,0 +1,47 @@
+import torch
+import pytest
+from functools import partial
+
+import torch.multiprocessing as mp
+import torch.distributed as dist
+
+import colossalai
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils import free_port
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec
+from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
+from tests.test_tensor._utils import tensor_shard_equal
+
+
+def run_dist(rank, world_size, port, dp_degree, tp_degree):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
+    x = torch.randn(4, 4, device=get_current_device())
+    param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
+    spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
+    param.set_tensor_spec(*spec)
+
+    gather_tensor(param)
+    if dist.get_rank() == 0:
+        assert torch.allclose(x, param.data, rtol=0, atol=0)
+    else:
+        assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
+    dist.barrier()
+
+    scatter_tensor(param, spec[0])
+    assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
+    assert param.requires_grad is True
+    dist.barrier()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [4])
+@rerun_if_address_is_in_use()
+def test_checkpoint(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2)
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_checkpoint(world_size=4)
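
Below is a minimal usage sketch of the gather/scatter round trip that this
patch introduces, mirroring tests/test_utils/test_colo_checkpoint_tools.py. It
assumes colossalai.launch() has already set up four processes (2 DP x 2 TP);
the file name 'param.pth' is illustrative only.

    import torch
    import torch.distributed as dist

    from colossalai.tensor import (ColoTensor, ColoTensorSpec, ComputePattern,
                                   ComputeSpec, ProcessGroup, ShardSpec)
    from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
    from colossalai.utils.cuda import get_current_device

    pg = ProcessGroup(dp_degree=2, tp_degree=2)
    x = torch.randn(4, 4, device=get_current_device())
    param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
    # shard the last dim across the TP group
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    param.set_tensor_spec(shard_spec, ComputeSpec(ComputePattern.TP1D))

    gather_tensor(param)    # rank 0 now holds the whole replicated tensor
    if dist.get_rank() == 0:
        torch.save(param.data, 'param.pth')    # safe: data is complete on rank 0
    scatter_tensor(param, shard_spec)    # every rank returns to its own shard

save_checkpoint and load_checkpoint above apply exactly this round trip to
every ColoTensor in the model and optimizer state dicts, so only rank 0 ever
materializes and writes the full tensors.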