diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index e6abf59e3..79bb33dca 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -21,7 +21,7 @@ from colossalai.tensor.padded_tensor import (
     to_padded_tensor,
     to_unpadded_tensor,
 )
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, get_non_persistent_buffers_set
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -105,8 +105,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 yield block, block_size
 
         # Save buffers.
+        non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
-            if buf is not None and name not in model._non_persistent_buffers_set:
+            if buf is not None and name not in non_persist_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
                 block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
                 if block is not None:
@@ -352,9 +353,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
            _load(name)
 
        # Load buffers.
-       non_persistent_buffers = set()
-       for n, m in model.named_modules():
-           non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+       non_persistent_buffers = get_non_persistent_buffers_set(model)
        for name, buf in model.named_buffers():
            if buf is not None and name not in non_persistent_buffers:
                _load(name)
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index cdba46709..1605a5f4e 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -5,6 +5,7 @@ from .common import (
     ensure_path_exists,
     free_storage,
     get_current_device,
+    get_non_persistent_buffers_set,
     is_ddp_ignored,
     set_seed,
 )
@@ -25,4 +26,5 @@ __all__ = [
     "set_seed",
     "get_current_device",
     "is_ddp_ignored",
+    "get_non_persistent_buffers_set",
 ]
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 4a1889eb5..0863a812b 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -5,10 +5,11 @@ import os
 import random
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable
+from typing import Callable, Optional, Set
 
 import numpy as np
 import torch
+import torch.nn as nn
 
 from colossalai.accelerator import get_accelerator
 
@@ -76,3 +77,34 @@ def set_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
+
+
+def get_non_persistent_buffers_set(
+    module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+):
+    r"""
+    Args:
+        memo: a memo to store the set of modules already added to the result
+        prefix: a prefix that will be added to the name of the module
+        remove_duplicate: whether to remove the duplicated module instances in the result
+            or not
+    """
+
+    if memo is None:
+        memo = set()
+    self_non_persistent_set = set()
+    if module not in memo:
+        if remove_duplicate:
+            memo.add(module)
+        self_non_persistent_set = set(
+            map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
+        )
+    for name, sub_module in module._modules.items():
+        if sub_module is None:
+            continue
+        submodule_prefix = prefix + ("." if prefix else "") + name
+        child_non_persistent_set = get_non_persistent_buffers_set(
+            sub_module, memo, submodule_prefix, remove_duplicate
+        )
+        self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+    return self_non_persistent_set
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index dbaae6610..680ff07fd 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -35,7 +35,7 @@ from colossalai.tensor.padded_tensor import (
     to_unpadded_tensor,
 )
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, get_non_persistent_buffers_set, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -187,7 +187,7 @@ class GeminiDDP(ModelWrapper):
             pin_memory=pin_memory,
         )
         super().__init__(module)
-        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
         # register grad hook
@@ -257,36 +257,6 @@ class GeminiDDP(ModelWrapper):
         for p in params_to_ignore:
             p._ddp_to_ignore = True
 
-    def _get_non_persistent_buffers_set(
-        self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
-    ):
-        r"""
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-        """
-
-        if memo is None:
-            memo = set()
-        self_non_persistent_set = set()
-        if module not in memo:
-            if remove_duplicate:
-                memo.add(module)
-            self_non_persistent_set = set(
-                map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
-            )
-        for name, sub_module in module._modules.items():
-            if sub_module is None:
-                continue
-            submodule_prefix = prefix + ("." if prefix else "") + name
-            child_non_persistent_set = self._get_non_persistent_buffers_set(
-                sub_module, memo, submodule_prefix, remove_duplicate
-            )
-            self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-        return self_non_persistent_set
-
     def _post_forward(self):
         """This function is only triggered for inference."""
         access_list = list(self.chunk_manager.accessed_chunks)
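For reviewers, a minimal usage sketch of the new shared helper follows. The toy modules (ToyBlock, ToyModel) are hypothetical and not part of this patch; the sketch only assumes ColossalAI is installed with this change applied, so that get_non_persistent_buffers_set is importable from colossalai.utils.

# Minimal sketch (illustrative only): ToyBlock and ToyModel are made up here,
# only get_non_persistent_buffers_set comes from this PR.
import torch
import torch.nn as nn

from colossalai.utils import get_non_persistent_buffers_set


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: included in state_dict, so the helper does not report it.
        self.register_buffer("running_mean", torch.zeros(4))
        # Non-persistent buffer: excluded from state_dict, so the helper reports it.
        self.register_buffer("scratch", torch.zeros(4), persistent=False)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = ToyBlock()


model = ToyModel()
non_persistent = get_non_persistent_buffers_set(model)
print(non_persistent)  # {'block.scratch'}

# Checkpoint code can skip these fully qualified names while iterating
# named_buffers(), mirroring the save/load paths in HybridParallelCheckpointIO.
for name, buf in model.named_buffers():
    if buf is not None and name not in non_persistent:
        pass  # persist (or restore) this buffer

Compared with the inline named_modules() loop that the load path used before, the shared helper also applies the memo/remove_duplicate handling that GeminiDDP already relied on, so both call sites now resolve non-persistent buffer names the same way.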