mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-01 03:45:27 +00:00

[checkpointio] fix hybrid plugin model save (#6106)

This commit is contained in:
parent 89a9a600bc
commit c2e8f61592
@@ -21,7 +21,7 @@ from colossalai.tensor.padded_tensor import (
     to_padded_tensor,
     to_unpadded_tensor,
 )
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, get_non_persistent_buffers_set
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -105,8 +105,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 yield block, block_size
 
         # Save buffers.
+        non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
-            if buf is not None and name not in model._non_persistent_buffers_set:
+            if buf is not None and name not in non_persist_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
                 block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
                 if block is not None:
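
Aside (illustrative, not part of the commit): a minimal sketch of the behavior this hunk fixes. `model.named_buffers()` yields fully qualified names, while each module's `_non_persistent_buffers_set` only stores local names, so checking a qualified name against the top-level set misses non-persistent buffers registered on submodules. The `Sub`/`Model` classes below are assumptions made up for the example.

# Illustrative sketch: why the old membership check missed nested non-persistent buffers.
import torch
import torch.nn as nn

class Sub(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffer: should be skipped when saving a checkpoint.
        self.register_buffer("cache", torch.zeros(4), persistent=False)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

model = Model()
print([name for name, _ in model.named_buffers()])  # ['sub.cache'] -- fully qualified
print(model._non_persistent_buffers_set)            # set() -- 'cache' lives on model.sub
# Old check: 'sub.cache' not in model._non_persistent_buffers_set -> True,
# so the non-persistent buffer was written into the saved shard anyway.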
@@ -352,9 +353,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             _load(name)
 
         # Load buffers.
-        non_persistent_buffers = set()
-        for n, m in model.named_modules():
-            non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+        non_persistent_buffers = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persistent_buffers:
                 _load(name)
@@ -5,6 +5,7 @@ from .common import (
     ensure_path_exists,
     free_storage,
     get_current_device,
+    get_non_persistent_buffers_set,
     is_ddp_ignored,
     set_seed,
 )
@@ -25,4 +26,5 @@ __all__ = [
     "set_seed",
     "get_current_device",
     "is_ddp_ignored",
+    "get_non_persistent_buffers_set",
 ]
@@ -5,10 +5,11 @@ import os
 import random
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable
+from typing import Callable, Optional, Set
 
 import numpy as np
 import torch
+import torch.nn as nn
 
 from colossalai.accelerator import get_accelerator
 
@@ -76,3 +77,34 @@ def set_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
+
+
+def get_non_persistent_buffers_set(
+    module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+):
+    r"""
+    Args:
+        memo: a memo to store the set of modules already added to the result
+        prefix: a prefix that will be added to the name of the module
+        remove_duplicate: whether to remove the duplicated module instances in the result
+            or not
+    """
+
+    if memo is None:
+        memo = set()
+    self_non_persistent_set = set()
+    if module not in memo:
+        if remove_duplicate:
+            memo.add(module)
+        self_non_persistent_set = set(
+            map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
+        )
+    for name, sub_module in module._modules.items():
+        if sub_module is None:
+            continue
+        submodule_prefix = prefix + ("." if prefix else "") + name
+        child_non_persistent_set = get_non_persistent_buffers_set(
+            sub_module, memo, submodule_prefix, remove_duplicate
+        )
+        self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+    return self_non_persistent_set
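
Aside (illustrative, not part of the commit): a small usage sketch of the helper added above, reusing the toy `model` from the earlier sketch. It walks the module tree and returns fully qualified non-persistent buffer names, which can then be used to filter `named_buffers()` the same way the checkpoint IO hunks do.

from colossalai.utils import get_non_persistent_buffers_set

non_persistent = get_non_persistent_buffers_set(model)  # {'sub.cache'}
# Keep only persistent buffers, mirroring the save/load changes above.
buffers_to_save = {name: buf for name, buf in model.named_buffers() if name not in non_persistent}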
@@ -35,7 +35,7 @@ from colossalai.tensor.padded_tensor import (
     to_unpadded_tensor,
 )
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, get_non_persistent_buffers_set, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -187,7 +187,7 @@ class GeminiDDP(ModelWrapper):
             pin_memory=pin_memory,
         )
         super().__init__(module)
-        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
         # register grad hook
@@ -257,36 +257,6 @@ class GeminiDDP(ModelWrapper):
         for p in params_to_ignore:
             p._ddp_to_ignore = True
 
-    def _get_non_persistent_buffers_set(
-        self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
-    ):
-        r"""
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-        """
-
-        if memo is None:
-            memo = set()
-        self_non_persistent_set = set()
-        if module not in memo:
-            if remove_duplicate:
-                memo.add(module)
-            self_non_persistent_set = set(
-                map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
-            )
-        for name, sub_module in module._modules.items():
-            if sub_module is None:
-                continue
-            submodule_prefix = prefix + ("." if prefix else "") + name
-            child_non_persistent_set = self._get_non_persistent_buffers_set(
-                sub_module, memo, submodule_prefix, remove_duplicate
-            )
-            self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-        return self_non_persistent_set
-
     def _post_forward(self):
         """This function is only triggered for inference."""
         access_list = list(self.chunk_manager.accessed_chunks)