mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
[lazy] refactor lazy init (#3891)
* [lazy] remove old lazy init * [lazy] refactor lazy init folder structure * [lazy] fix lazy tensor deepcopy * [test] update lazy init test
This commit is contained in:
@@ -2,13 +2,14 @@ import itertools
|
||||
from collections import OrderedDict
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
|
||||
from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.checkpoint_io.utils import calculate_tensor_size
|
||||
from colossalai.lazy import LazyTensor
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
@@ -16,7 +17,6 @@ from colossalai.tensor import ReplicaSpec
|
||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import get_current_device, is_ddp_ignored
|
||||
from colossalai.utils.model.experimental import LazyTensor
|
||||
|
||||
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
|
||||
from .gemini_hook import GeminiZeROHook
|
||||
@@ -96,34 +96,38 @@ class ZeroDDP(ColoDDP):
|
||||
param_name = m_name + '.' + p_name if m_name else p_name
|
||||
self.name2param[param_name] = p_var
|
||||
super().__init__(module, process_group=ColoProcessGroup())
|
||||
self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
|
||||
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
|
||||
self._cast_buffers()
|
||||
|
||||
def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
|
||||
def _get_non_persistent_buffers_set(self,
|
||||
module,
|
||||
memo: Optional[Set[nn.Module]] = None,
|
||||
prefix: str = '',
|
||||
remove_duplicate: bool = True):
|
||||
r"""
|
||||
Args:
|
||||
memo: a memo to store the set of modules already added to the result
|
||||
prefix: a prefix that will be added to the name of the module
|
||||
remove_duplicate: whether to remove the duplicated module instances in the result
|
||||
or not
|
||||
"""
|
||||
|
||||
r"""
|
||||
Args:
|
||||
memo: a memo to store the set of modules already added to the result
|
||||
prefix: a prefix that will be added to the name of the module
|
||||
remove_duplicate: whether to remove the duplicated module instances in the result
|
||||
or not
|
||||
"""
|
||||
|
||||
if memo is None:
|
||||
memo = set()
|
||||
self_non_persistent_set = set()
|
||||
if module not in memo:
|
||||
if remove_duplicate:
|
||||
memo.add(module)
|
||||
self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
|
||||
for name, sub_module in module._modules.items():
|
||||
if sub_module is None:
|
||||
continue
|
||||
submodule_prefix = prefix + ('.' if prefix else '') + name
|
||||
child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
|
||||
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
|
||||
return self_non_persistent_set
|
||||
|
||||
if memo is None:
|
||||
memo = set()
|
||||
self_non_persistent_set = set()
|
||||
if module not in memo:
|
||||
if remove_duplicate:
|
||||
memo.add(module)
|
||||
self_non_persistent_set = set(
|
||||
map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
|
||||
for name, sub_module in module._modules.items():
|
||||
if sub_module is None:
|
||||
continue
|
||||
submodule_prefix = prefix + ('.' if prefix else '') + name
|
||||
child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
|
||||
remove_duplicate)
|
||||
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
|
||||
return self_non_persistent_set
|
||||
|
||||
def _post_forward(self):
|
||||
"""This function is only triggered for inference.
|
||||
|
Reference in New Issue
Block a user