From d26902645e7c4b7453838648d2fa52291a439b0a Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 20 Jun 2022 10:51:47 +0800
Subject: [PATCH] [ddp] add save/load state dict for ColoDDP (#1127)

* add save/load state dict for ColoDDP

* add unit test

* refactor unit test folder

* polish unit test

* rename unit test
---
 colossalai/nn/parallel/data_parallel.py | 222 +++++++++++++++++
 .../test_ddp_ignore_params.py           |   0
 tests/test_ddp/test_ddp_state_dict.py   |  66 ++++++
 3 files changed, 286 insertions(+), 2 deletions(-)
 rename tests/{test_utils => test_ddp}/test_ddp_ignore_params.py (100%)
 create mode 100644 tests/test_ddp/test_ddp_state_dict.py

diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 53d126a79..f88534bc3 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -1,4 +1,5 @@
 import torch
+import itertools
 import torch.distributed as dist
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
@@ -7,8 +8,14 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.tensor.chunk import TensorState, Chunk
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Dict, Iterable
+from typing import Dict, Iterable, List
 from colossalai.logging import get_dist_logger
+from collections import OrderedDict
+from colossalai.tensor.colo_parameter import ColoParameter
+try:
+    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
+except ImportError:
+    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 
 
 def free_storage(data: torch.Tensor) -> None:
@@ -122,6 +129,12 @@ class ColoDDP(torch.nn.Module):
         for p in params_to_ignore:
             p._ddp_to_ignore = True
 
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+
+    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
+        return self.module.load_state_dict(state_dict, strict)
+
 
 class ColoDDPV2(ColoDDP):
 
@@ -130,7 +143,7 @@ class ColoDDPV2(ColoDDP):
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
-        self.fp32_params = []
+        self.fp32_params: List[ColoParameter] = []
         self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
         self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
@@ -205,3 +218,208 @@
     def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
         for tensor in chunk.get_tensors():
             self.grads_device[tensor] = device
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        r"""Returns a dictionary containing a whole state of the module.
+
+        Both parameters and persistent buffers (e.g. running averages) are
+        included. Keys are corresponding parameter and buffer names.
+        Parameters and buffers set to ``None`` are not included.
+
+        Returns:
+            dict:
+                a dictionary containing a whole state of the module
+
+        Example::
+
+            >>> module.state_dict().keys()
+            ['bias', 'weight']
+
+        """
+        if destination is None:
+            destination = OrderedDict()
+            destination._metadata = OrderedDict()
+        destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
+        self._save_to_state_dict(destination, prefix, keep_vars)
+        for hook in self._state_dict_hooks.values():
+            hook_result = hook(self, destination, prefix, local_metadata)
+            if hook_result is not None:
+                destination = hook_result
+        return destination
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        r"""Saves module state to `destination` dictionary, containing a state
+        of the module, but not its descendants. This is called on every
+        submodule in :meth:`~torch.nn.Module.state_dict`.
+
+        In rare cases, subclasses can achieve class-specific behavior by
+        overriding this method with custom logic.
+
+        Args:
+            destination (dict): a dict where state will be stored
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+        """
+        chunks = self.chunk_manager.get_chunks(self.fp32_params)
+        for chunk in chunks:
+            self.chunk_manager.access_chunk(chunk)
+        for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
+            if p is not None:
+                destination[prefix + name] = fp32_p.clone() if keep_vars else fp32_p.clone().detach()
+        for chunk in chunks:
+            self.chunk_manager.release_chunk(chunk)
+        for name, buf in self.named_buffers():
+            if buf is not None and name not in self._non_persistent_buffers_set:
+                destination[prefix + name] = buf if keep_vars else buf.detach()
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state",
+                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+            destination[extra_state_key] = self.get_extra_state()
+
+    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
+        r"""Copies parameters and buffers from :attr:`state_dict` into
+        this module and its descendants. If :attr:`strict` is ``True``, then
+        the keys of :attr:`state_dict` must exactly match the keys returned
+        by this module's :meth:`~torch.nn.Module.state_dict` function.
+
+        Args:
+            state_dict (dict): a dict containing parameters and
+                persistent buffers.
+            strict (bool, optional): whether to strictly enforce that the keys
+                in :attr:`state_dict` match the keys returned by this module's
+                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+
+        Returns:
+            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+                * **missing_keys** is a list of str containing the missing keys
+                * **unexpected_keys** is a list of str containing the unexpected keys
+
+        Note:
+            If a parameter or buffer is registered as ``None`` and its corresponding key
+            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
+            ``RuntimeError``.
+        """
+        missing_keys: List[str] = []
+        unexpected_keys: List[str] = []
+        error_msgs: List[str] = []
+
+        # copy state_dict so _load_from_state_dict can modify it
+        metadata = getattr(state_dict, '_metadata', None)
+        state_dict = state_dict.copy()
+        if metadata is not None:
+            # mypy isn't aware that "_metadata" exists in state_dict
+            state_dict._metadata = metadata    # type: ignore[attr-defined]
+
+        prefix = ''
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+
+        if strict:
+            if len(unexpected_keys) > 0:
+                error_msgs.insert(
+                    0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
+                        '"{}"'.format(k) for k in unexpected_keys)))
+            if len(missing_keys) > 0:
+                error_msgs.insert(
+                    0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
+
+        if len(error_msgs) > 0:
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                self.__class__.__name__, "\n\t".join(error_msgs)))
+        return _IncompatibleKeys(missing_keys, unexpected_keys)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                              error_msgs):
+        r"""Copies parameters and buffers from :attr:`state_dict` into only
+        this module, but not its descendants. This is called on every submodule
+        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
+        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+        For state dicts without metadata, :attr:`local_metadata` is empty.
+        Subclasses can achieve class-specific backward compatible loading using
+        the version number at `local_metadata.get("version", None)`.
+
+        .. note::
+            :attr:`state_dict` is not the same object as the input
+            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
+            it can be modified.
+
+        Args:
+            state_dict (dict): a dict containing parameters and
+                persistent buffers.
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+            local_metadata (dict): a dict containing the metadata for this module.
+                See
+            strict (bool): whether to strictly enforce that the keys in
+                :attr:`state_dict` with :attr:`prefix` match the names of
+                parameters and buffers in this module
+            missing_keys (list of str): if ``strict=True``, add missing keys to
+                this list
+            unexpected_keys (list of str): if ``strict=True``, add unexpected
+                keys to this list
+            error_msgs (list of str): error messages should be added to this
+                list, and will be reported together in
+                :meth:`~torch.nn.Module.load_state_dict`
+        """
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        def load(name, dest_tensor, copy_func):
+            key = prefix + name
+            if key in state_dict:
+                input_param = state_dict[key]
+                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+                if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
+                    input_param = input_param[0]
+                if input_param.shape != dest_tensor.shape:
+                    # local shape should match the one in checkpoint
+                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                      'the shape in current model is {}.'.format(key, input_param.shape,
+                                                                                 dest_tensor.shape))
+                    return
+                try:
+                    with torch.no_grad():
+                        # self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, input_param)
+                        copy_func(input_param)
+                except Exception as ex:
+                    error_msgs.append('While copying the parameter named "{}", '
+                                      'whose dimensions in the model are {} and '
+                                      'whose dimensions in the checkpoint are {}, '
+                                      'an exception occurred : {}.'.format(key, dest_tensor.size(), input_param.size(),
+                                                                           ex.args))
+            elif strict:
+                missing_keys.append(key)
+
+        def load_fp32_p(fp32_p, data):
+            if fp32_p.storage().size() > 0:
+                self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, data)
+
+        for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
+            if p is not None:
+                load(name, fp32_p, partial(load_fp32_p, fp32_p))
+        self.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
+
+        for name, buf in persistent_buffers.items():
+            if buf is not None:
+                load(name, buf, buf.copy_)
+
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "set_extra_state",
+                   torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
+            if extra_state_key in state_dict:
+                self.set_extra_state(state_dict[extra_state_key])
+            elif strict:
+                missing_keys.append(extra_state_key)
+        elif strict and (extra_state_key in state_dict):
+            unexpected_keys.append(extra_state_key)
+
+        if strict:
+            for key in state_dict.keys():
+                if key.startswith(prefix) and key != extra_state_key:
+                    input_name = key[len(prefix):]
+                    if input_name not in local_state:
+                        unexpected_keys.append(key)
diff --git a/tests/test_utils/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py
similarity index 100%
rename from tests/test_utils/test_ddp_ignore_params.py
rename to tests/test_ddp/test_ddp_ignore_params.py
diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py
new file mode 100644
index 000000000..37de68b81
--- /dev/null
+++ b/tests/test_ddp/test_ddp_state_dict.py
@@ -0,0 +1,66 @@
+import pytest
+import colossalai
+import torch
+import torch.multiprocessing as mp
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils import free_port
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import ChunkManager
+from functools import partial
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.nn.parallel import ColoDDPV2, ColoDDP
+from colossalai.gemini.gemini_mgr import GeminiManager
+from typing import Callable
+from collections import OrderedDict
+
+
+def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
+    for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()):
+        assert k1 == k2
+        assert torch.allclose(t1, t2, atol=1e-3, rtol=1e-3)
+
+
+def init_ddp(module: torch.nn.Module) -> ColoDDP:
+    return ColoDDP(module)
+
+
+def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ColoDDPV2:
+    chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
+    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    return ColoDDPV2(module, gemini_manager)
+
+
+def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
+    get_components_func = non_distributed_component_funcs.get_callable('nested_model')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    torch_model = model_builder().cuda()
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+    model = ddp_init_func(model)
+    torch_state_dict = torch_model.state_dict()
+    model.load_state_dict(torch_state_dict)
+    state_dict = model.state_dict()
+    check_state_dict_equal(torch_state_dict, state_dict)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_state_dict(init_ddp)
+    run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=False))
+    run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=True))
+    run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=False))
+    run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=True))
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_state_dict(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_state_dict(2)
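
Usage sketch (illustrative, not part of the diff above): the new unit test only checks state dict round-trips, but the same API can back a simple checkpoint helper. The sketch below reuses only calls that already appear in this patch (ColoInitContext, ChunkManager, GeminiManager, ColoDDPV2, state_dict/load_state_dict) and assumes it runs inside a process group set up by colossalai.launch; build_model_fn and the checkpoint path are placeholders, and saving from rank 0 alone relies on state_dict() returning a full, gathered copy on every rank (which is what the chunk access/release loop in _save_to_state_dict provides) plus a shared filesystem.

    import torch
    import torch.distributed as dist
    from colossalai.utils.cuda import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext
    from colossalai.tensor import ChunkManager
    from colossalai.gemini.gemini_mgr import GeminiManager
    from colossalai.nn.parallel import ColoDDPV2


    def build_zero_model(build_model_fn, use_chunk: bool = True, use_zero: bool = True) -> ColoDDPV2:
        # Same wiring as init_ddpv2 in the new test: ColoTensor params, chunked storage, Gemini placement.
        with ColoInitContext(device=get_current_device()):
            module = build_model_fn()
        chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
        chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
        gemini_manager = GeminiManager('cuda', chunk_manager)
        return ColoDDPV2(module, gemini_manager)


    def save_checkpoint(model: ColoDDPV2, path: str) -> None:
        # Every rank must take part in state_dict() so the fp32 master chunks can be gathered;
        # each rank then holds a complete copy, and writing from rank 0 only avoids duplicate files.
        state_dict = model.state_dict()
        if dist.get_rank() == 0:
            torch.save(state_dict, path)


    def load_checkpoint(model: ColoDDPV2, path: str) -> None:
        # Every rank loads the same full state dict; load_state_dict copies the values into the
        # fp32 chunks and then refreshes the fp16 working copies via copy_chunk_group.
        state_dict = torch.load(path, map_location=get_current_device())
        model.load_state_dict(state_dict)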
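
A second note on loading semantics, again a sketch under assumptions rather than part of the commit: ColoDDPV2.load_state_dict follows torch.nn.Module here, in that missing and unexpected keys are always collected and returned as an _IncompatibleKeys named tuple, while strict only decides whether key mismatches raise a RuntimeError (shape mismatches still raise either way). Non-strict loading of a partially matching checkpoint can therefore look like this; checkpoint_path is a placeholder.

    import torch
    from colossalai.nn.parallel import ColoDDPV2


    def load_partial_checkpoint(model: ColoDDPV2, checkpoint_path: str) -> None:
        # Copy whatever overlaps between the checkpoint and the model; report the rest instead of raising.
        state_dict = torch.load(checkpoint_path, map_location='cuda')
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            print('missing keys:', incompatible.missing_keys)
        if incompatible.unexpected_keys:
            print('unexpected keys:', incompatible.unexpected_keys)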