Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 05:20:33 +00:00
[checkpointio] fix hybrid plugin model save (#6106)
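This change widens the typing import to Callable, Optional, Set, adds import torch.nn as nn, and introduces a recursive helper, get_non_persistent_buffers_set, that walks a module tree and collects the fully qualified (dotted) names of every non-persistent buffer. Buffers registered with persistent=False are excluded from a module's state_dict(), so the hybrid plugin's checkpoint writer presumably uses this set to know which buffer entries to skip when saving the model. A usage sketch follows the diff.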
@@ -5,10 +5,11 @@ import os
 import random
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable
+from typing import Callable, Optional, Set
 
 import numpy as np
 import torch
+import torch.nn as nn
 
 from colossalai.accelerator import get_accelerator
 
@@ -76,3 +77,34 @@ def set_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
+
+
+def get_non_persistent_buffers_set(
+    module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+):
+    r"""
+    Args:
+        memo: a memo to store the set of modules already added to the result
+        prefix: a prefix that will be added to the name of the module
+        remove_duplicate: whether to remove the duplicated module instances in the result
+            or not
+    """
+
+    if memo is None:
+        memo = set()
+    self_non_persistent_set = set()
+    if module not in memo:
+        if remove_duplicate:
+            memo.add(module)
+        self_non_persistent_set = set(
+            map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
+        )
+    for name, sub_module in module._modules.items():
+        if sub_module is None:
+            continue
+        submodule_prefix = prefix + ("." if prefix else "") + name
+        child_non_persistent_set = get_non_persistent_buffers_set(
+            sub_module, memo, submodule_prefix, remove_duplicate
+        )
+        self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+    return self_non_persistent_set
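As a quick illustration of what the helper returns, here is a minimal sketch. The Block and Net modules are invented for this example, and get_non_persistent_buffers_set is assumed to be in scope as defined in the diff above (the page does not show which file the commit touches, so no import path is claimed):

import torch
import torch.nn as nn

# Assumes get_non_persistent_buffers_set from the diff above is in scope.

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: included in state_dict() by default.
        self.register_buffer("running_mean", torch.zeros(4))
        # Non-persistent buffer: dropped from state_dict(), so checkpoint
        # code must know to skip it when saving and loading.
        self.register_buffer("cache", torch.zeros(4), persistent=False)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()
        self.head = nn.Linear(4, 2)

net = Net()
print(get_non_persistent_buffers_set(net))       # {'block.cache'}
print("block.cache" in net.state_dict())         # False
print("block.running_mean" in net.state_dict())  # True

With remove_duplicate=True, the memo ensures that a module instance shared by two parents contributes its buffer names only once, under the first prefix encountered during the walk.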