[checkpointio] fix hybrid plugin model save (#6106)

Author: Hongxin Liu
Date: 2024-10-31 17:04:53 +08:00
Committed by: GitHub
Parent: 89a9a600bc
Commit: c2e8f61592
4 changed files with 41 additions and 38 deletions
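The hunks shown below move GeminiDDP's private _get_non_persistent_buffers_set helper into colossalai.utils as a shared get_non_persistent_buffers_set utility. The other changed files are not shown here, but the commit title suggests the hybrid parallel plugin's checkpoint IO now reuses the same helper so that non-persistent buffers are skipped when saving a model. As a hedged illustration only (the function name gather_buffer_state_sketch and its filtering loop are mine, not the actual plugin code), any caller that builds a state dict manually from named_buffers() needs exactly this kind of filter, because named_buffers() yields non-persistent buffers while state_dict() does not:

    from typing import Dict

    import torch
    import torch.nn as nn

    from colossalai.utils import get_non_persistent_buffers_set


    def gather_buffer_state_sketch(module: nn.Module) -> Dict[str, torch.Tensor]:
        # Hypothetical sketch: collect buffers for a manually built checkpoint,
        # skipping non-persistent ones (buffers registered with persistent=False),
        # which is the filtering the shared helper makes available to checkpoint IO.
        non_persistent = get_non_persistent_buffers_set(module)
        return {
            name: buf.detach().cpu()
            for name, buf in module.named_buffers()
            if name not in non_persistent
        }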


@@ -35,7 +35,7 @@ from colossalai.tensor.padded_tensor import (
     to_unpadded_tensor,
 )
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, get_non_persistent_buffers_set, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -187,7 +187,7 @@ class GeminiDDP(ModelWrapper):
             pin_memory=pin_memory,
         )
         super().__init__(module)
-        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
         # register grad hook
@@ -257,36 +257,6 @@ class GeminiDDP(ModelWrapper):
         for p in params_to_ignore:
             p._ddp_to_ignore = True
 
-    def _get_non_persistent_buffers_set(
-        self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
-    ):
-        r"""
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-        """
-        if memo is None:
-            memo = set()
-        self_non_persistent_set = set()
-        if module not in memo:
-            if remove_duplicate:
-                memo.add(module)
-            self_non_persistent_set = set(
-                map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
-            )
-        for name, sub_module in module._modules.items():
-            if sub_module is None:
-                continue
-            submodule_prefix = prefix + ("." if prefix else "") + name
-            child_non_persistent_set = self._get_non_persistent_buffers_set(
-                sub_module, memo, submodule_prefix, remove_duplicate
-            )
-            self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-        return self_non_persistent_set
-
     def _post_forward(self):
         """This function is only triggered for inference."""
         access_list = list(self.chunk_manager.accessed_chunks)
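For reference, a sketch of the relocated helper as a module-level function, reconstructed from the method removed above; the actual code in colossalai.utils may differ in minor details such as type hints or docstring:

    from typing import Optional, Set

    import torch.nn as nn


    def get_non_persistent_buffers_set(
        module: nn.Module,
        memo: Optional[Set[nn.Module]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ) -> Set[str]:
        """Collect the fully qualified names of all non-persistent buffers under ``module``."""
        if memo is None:
            memo = set()
        self_non_persistent_set = set()
        if module not in memo:
            if remove_duplicate:
                memo.add(module)
            # nn.Module keeps its own non-persistent buffer names in this private set;
            # prepend the current prefix so the names match named_buffers()/state_dict() keys.
            self_non_persistent_set = set(
                map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
            )
        # Recurse into submodules, accumulating their prefixed buffer names.
        for name, sub_module in module._modules.items():
            if sub_module is None:
                continue
            submodule_prefix = prefix + ("." if prefix else "") + name
            child_non_persistent_set = get_non_persistent_buffers_set(
                sub_module, memo, submodule_prefix, remove_duplicate
            )
            self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
        return self_non_persistent_set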