[checkpointio] fix hybrid plugin model save (#6106)

This commit is contained in:
Hongxin Liu
2024-10-31 17:04:53 +08:00
committed by GitHub
parent 89a9a600bc
commit c2e8f61592
4 changed files with 41 additions and 38 deletions

View File

@@ -5,10 +5,11 @@ import os
import random
from contextlib import contextmanager
from pathlib import Path
from typing import Callable
from typing import Callable, Optional, Set
import numpy as np
import torch
import torch.nn as nn
from colossalai.accelerator import get_accelerator
@@ -76,3 +77,34 @@ def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def get_non_persistent_buffers_set(
module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
r"""
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
"""
if memo is None:
memo = set()
self_non_persistent_set = set()
if module not in memo:
if remove_duplicate:
memo.add(module)
self_non_persistent_set = set(
map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
)
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
child_non_persistent_set = get_non_persistent_buffers_set(
sub_module, memo, submodule_prefix, remove_duplicate
)
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
return self_non_persistent_set