From 152239bbfa6d5eca22633c9d73463ed8dcb300d7 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 12 Apr 2023 16:03:25 +0800 Subject: [PATCH] [gemini] gemini supports lazy init (#3379) * [gemini] fix nvme optimizer init * [gemini] gemini supports lazy init * [gemini] add init example * [gemini] add fool model * [zero] update gemini ddp * [zero] update init example * add chunk method * add chunk method * [lazyinit] fix lazy tensor tolist * [gemini] fix buffer materialization * [misc] remove useless file * [booster] update gemini plugin * [test] update gemini plugin test * [test] fix gemini plugin test * [gemini] fix import * [gemini] fix import * [lazyinit] use new metatensor * [lazyinit] use new metatensor * [lazyinit] fix __set__ method --- .../_analyzer/_subclasses/_monkey_patch.py | 3 +- colossalai/booster/plugin/gemini_plugin.py | 50 ------------------- colossalai/nn/optimizer/nvme_optimizer.py | 15 ++++-- colossalai/utils/model/experimental.py | 16 ++++-- colossalai/zero/gemini/chunk/search_utils.py | 9 ++-- colossalai/zero/gemini/gemini_ddp.py | 27 ++++++++-- .../test_plugin/test_gemini_plugin.py | 32 ++++++++++-- 7 files changed, 80 insertions(+), 72 deletions(-) diff --git a/colossalai/_analyzer/_subclasses/_monkey_patch.py b/colossalai/_analyzer/_subclasses/_monkey_patch.py index 7c1c3d3d8..b3ec98f08 100644 --- a/colossalai/_analyzer/_subclasses/_monkey_patch.py +++ b/colossalai/_analyzer/_subclasses/_monkey_patch.py @@ -2,8 +2,6 @@ import torch import torch.distributed as dist from packaging import version -aten = torch.ops.aten - __all__ = [ "_TorchFactoryMethod", "_TorchOverrideableFactoryMethod", @@ -51,6 +49,7 @@ _DistCommMethod = [ ] if version.parse(torch.__version__) >= version.parse('1.12.0'): + aten = torch.ops.aten # TODO: dive deep here # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp _AliasATen = [ diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 6693b1f44..659f36c21 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -16,10 +16,8 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io.utils import save_state_dict from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.tensor.colo_parameter import ColoParameter from colossalai.utils import get_current_device from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper -from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam from colossalai.zero.gemini.memory_tracer import MemStats from .plugin_base import Plugin @@ -27,50 +25,6 @@ from .plugin_base import Plugin __all__ = ['GeminiPlugin'] -def convert_to_colo_param(module: nn.Module) -> None: - """Convert module's paramters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini. - - Args: - module (nn.Module): Module to be converted. 
- """ - converted_modules = set() # handle shared modules - converted_params = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference - - def convert_recursively(m: nn.Module): - for child in m.children(): - if child not in converted_modules: - converted_modules.add(child) - convert_recursively(child) - - for name, p in m.named_parameters(recurse=False): - assert not isinstance(p, ColoParameter) - if p in converted_params: - target = converted_params[p] - else: - target = _convert_to_coloparam(p, p.device, p.dtype) - converted_params[p] = target - setattr(m, name, target) - target.shared_param_modules.append(m) - - convert_recursively(module) - - # optimizer should replace params in group as well. This attr should be deleted after replacing to avoid memory leak - module._converted_params = converted_params - - -def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None: - """Replace param in optimizer's group with converted ColoParameter. - - Args: - optimizer (Optimizer): Optimizer to be replaced. - converted_params (dict): Mapping between (torch.Tensor, ColoTensor). - """ - for group in optimizer.param_groups: - for i, p in enumerate(group['params']): - if p in converted_params: - group['params'][i] = converted_params[p] - - class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: @@ -113,8 +67,6 @@ class GeminiModel(ModelWrapper): def __init__(self, module: nn.Module, gemini_config: dict) -> None: super().__init__(module) - # TODO(ver217): only support Gemini now - convert_to_colo_param(module) self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config) def unwrap(self): @@ -125,8 +77,6 @@ class GeminiModel(ModelWrapper): class GeminiOptimizer(OptimizerWrapper): def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None: - replace_param_in_group(optimizer, module.module._converted_params) - del module.module._converted_params optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs) super().__init__(optimizer) diff --git a/colossalai/nn/optimizer/nvme_optimizer.py b/colossalai/nn/optimizer/nvme_optimizer.py index cbb435a90..53e4a46c9 100644 --- a/colossalai/nn/optimizer/nvme_optimizer.py +++ b/colossalai/nn/optimizer/nvme_optimizer.py @@ -1,9 +1,10 @@ -import torch +import math import os import tempfile -import math +from typing import Callable, Dict, List, Optional + +import torch from torch.nn.parameter import Parameter -from typing import Optional, List, Dict, Callable class NVMeOptimizer(torch.optim.Optimizer): @@ -42,8 +43,9 @@ class NVMeOptimizer(torch.optim.Optimizer): self.offloader = None self.is_on_nvme: Dict[Parameter, bool] = {} self.offloaded_numel: int = 0 - self.total_numel: int = self._get_numel() - self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction) + # As param may be not materialized here, these attributes are initalized when the first step + self.total_numel: Optional[int] = None + self.can_offload_numel: Optional[int] = None self.prefetch_params: List[Parameter] = [] self.param_to_prefetch_idx: Dict[Parameter, int] = {} @@ -77,6 +79,9 @@ class NVMeOptimizer(torch.optim.Optimizer): self.prefetch_params.append(p) def _pre_step(self, *state_keys: str) -> None: + if self.total_numel is None: + self.total_numel = self._get_numel() + self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction) self._setup_prefetch_params() if 
self.offloader is None or len(self.prefetch_params) == 0: return diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py index 6427a147a..c91751f1c 100644 --- a/colossalai/utils/model/experimental.py +++ b/colossalai/utils/model/experimental.py @@ -7,7 +7,7 @@ import torch.nn as nn from torch import Tensor from torch.utils._pytree import tree_map -from colossalai.fx.profiler.tensor import MetaTensor +from colossalai._analyzer._subclasses import MetaTensor from colossalai.tensor.d_tensor.d_tensor import DTensor from colossalai.tensor.d_tensor.layout import Layout @@ -37,7 +37,7 @@ _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split'] # If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset) # without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block. # These ops cannot be unwrapped using .data -_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__'] +_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__'] _LEGACY_TENSOR_CONSTRUCTOR = { 'FloatTensor': torch.float, @@ -75,6 +75,12 @@ class _MyTensor(Tensor): return super().__torch_function__(func, types, args, kwargs) +def _data_tolist(tensor: torch.Tensor) -> list: + """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor. + """ + return tensor.data.tolist() + + def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: """Convert a lazy tensor's class to target's class, with target's data. @@ -94,7 +100,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: tensor.requires_grad = target.requires_grad # subclass of torch.Tensor does not have tolist() method # overwrite this method after materialization or distribution - tensor.tolist = MethodType(torch.Tensor.tolist, target) + tensor.tolist = MethodType(_data_tolist, tensor) return tensor @@ -144,7 +150,7 @@ class LazyTensor(torch.Tensor): if meta_data is None: device = kwargs.get('device', 'cpu') elem = func(*args, **{**kwargs, 'device': 'meta'}) - meta_data = MetaTensor(elem, fake_device=device) + meta_data = MetaTensor(elem, device=device) elem = meta_data._tensor # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad) @@ -255,7 +261,7 @@ class LazyTensor(torch.Tensor): tree_map(cls._replace_with_materialized, args) tree_map(cls._replace_with_materialized, kwargs) is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__')) - or func.__name__ == "__setitem__") + or func.__name__ in ('__setitem__', '__set__')) is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS diff --git a/colossalai/zero/gemini/chunk/search_utils.py b/colossalai/zero/gemini/chunk/search_utils.py index a69b782ea..c4deec8fe 100644 --- a/colossalai/zero/gemini/chunk/search_utils.py +++ b/colossalai/zero/gemini/chunk/search_utils.py @@ -46,9 +46,10 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool): - if strict_ddp_flag: + if strict_ddp_flag and type(local_param) is ColoParameter: return local_param.numel_global() else: + # if local_param is not ColoParameter, we assume it's replicated return local_param.numel() @@ -67,11 +68,13 @@ def classify_params_by_dp_degree(param_order: 
OrderedParamGenerator, """ params_dict: Dict[int, List[ColoParameter]] = dict() for param in param_order.generate(): - assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" + # assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" if is_ddp_ignored(param): continue - if strict_ddp_flag: + if strict_ddp_flag or type(param) is not ColoParameter: + # if model is not initialized with ColoInitContext, we assume it's replicated + # TODO(ver217): integrate DTensor param_key = dist.get_world_size() else: param_key = param.process_group.dp_world_size() diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 50f1b1ef1..c06239dfa 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -1,7 +1,7 @@ import itertools from collections import OrderedDict from functools import partial -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import torch import torch.distributed as dist @@ -14,6 +14,7 @@ from colossalai.tensor import ReplicaSpec from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.utils import get_current_device, is_ddp_ignored +from colossalai.utils.model.experimental import LazyTensor from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager from .gemini_hook import GeminiZeROHook @@ -55,7 +56,6 @@ class ZeroDDP(ColoDDP): pin_memory: bool = False, force_outputs_fp32: bool = False, strict_ddp_mode: bool = False) -> None: - super().__init__(module, process_group=ColoProcessGroup()) self.gemini_manager = gemini_manager self.chunk_manager: ChunkManager = gemini_manager.chunk_manager self.force_outputs_fp32 = force_outputs_fp32 @@ -67,7 +67,6 @@ class ZeroDDP(ColoDDP): self.param2name: Dict[nn.Parameter, str] = dict() self.name2param: Dict[str, nn.Parameter] = dict() - self._cast_buffers() self._logger = get_dist_logger() if self.gemini_manager._premade_memstats_: @@ -91,6 +90,8 @@ class ZeroDDP(ColoDDP): for p_name, p_var in m_var.named_parameters(recurse=False): param_name = m_name + '.' + p_name if m_name else p_name self.name2param[param_name] = p_var + super().__init__(module, process_group=ColoProcessGroup()) + self._cast_buffers() def _post_forward(self): """This function is only triggered for inference. @@ -478,7 +479,8 @@ class ZeroDDP(ColoDDP): def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool): ddp_pg = ColoProcessGroup() for p in param_order.generate(): - assert isinstance(p, ColoParameter) + self._preprocess_param(p) + assert type(p) is ColoParameter # gather sharded parameters in the strict ddp mode if strict_ddp_mode: @@ -531,10 +533,27 @@ class ZeroDDP(ColoDDP): def _cast_buffers(self): for buffer in self.module.buffers(): + if isinstance(buffer, LazyTensor): + buffer.materialize() buffer.data = buffer.cuda() if torch.is_floating_point(buffer): buffer.data = buffer.half() + def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None: + """Convert parameter to ColoParameter in-place. 
+ Args: + p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted + """ + if type(p) is ColoParameter: + # model is initialized with ColoInitContext + return + requires_grad = p.requires_grad + if isinstance(p, LazyTensor): + # model is initialized with LazyInitContext + p.materialize() + p.__class__ = ColoParameter + p.__init__(p, requires_grad=requires_grad) + class GeminiDDP(ZeroDDP): diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index a3c63fd09..d804c727a 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -1,21 +1,31 @@ +from contextlib import nullcontext + import torch import torch.distributed as dist import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin +from colossalai.fx import is_compatible_with_meta from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.zero import ColoInitContext from tests.kit.model_zoo import model_zoo -def check_gemini_plugin(early_stop: bool = True): +@parameterize('init_method', ['lazy', 'none', 'colo']) +def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): """check gemini plugin over model zoo Args: early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. """ + is_support_meta = is_compatible_with_meta() + if not is_support_meta and init_method == 'lazy': + return + + from colossalai.utils.model.experimental import LazyInitContext passed_models = [] failed_info = {} # (model_name, error) pair @@ -40,10 +50,25 @@ def check_gemini_plugin(early_stop: bool = True): ]: continue + if init_method == 'lazy' and name in [ + 'timm_convmixer', 'timm_vision_transformer', 'timm_deit', 'timm_deit3', 'timm_inception_v3', + 'timm_tnt_b_patch16_224', 'timm_rexnet', 'torchvision_densenet121', 'torchvision_efficientnet_b0', + 'torchvision_mobilenet_v2', 'torchvision_mnasnet0_5', 'torchvision_regnet_x_16gf', + 'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s' + ]: + continue + try: + if init_method == 'colo': + ctx = ColoInitContext() + elif init_method == 'lazy': + ctx = LazyInitContext() + else: + ctx = nullcontext() plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5) booster = Booster(plugin=plugin) - model = model_fn() + with ctx: + model = model_fn() optimizer = HybridAdam(model.parameters(), lr=1e-3) criterion = lambda x: x.mean() data = data_gen_fn() @@ -76,6 +101,7 @@ def check_gemini_plugin(early_stop: bool = True): torch.cuda.empty_cache() if dist.get_rank() == 0: + print(f'Init method: {init_method}') print(f'Passed models({len(passed_models)}): {passed_models}\n\n') print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
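
Usage sketch (illustrative, not part of the diff above): the new test exercises lazy init by building the model under LazyInitContext and handing it to GeminiPlugin through Booster, after which the lazy parameters are materialized and converted to ColoParameter inside ZeroDDP._preprocess_param. Below is a minimal stand-alone version of that flow. It assumes a single-GPU torchrun launch; the toy model, tensor shapes, and learning rate are placeholders, the plugin arguments and imports mirror the test in this patch, and the boost/backward calls follow the general Booster API rather than anything introduced here.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils.model.experimental import LazyInitContext

# assumes `torchrun` set up the distributed environment variables
colossalai.launch_from_torch(config={})

# Parameters created inside the context are LazyTensors; Gemini materializes
# them and converts them to ColoParameter when ZeroDDP is constructed.
with LazyInitContext():
    model = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 8))

# HybridAdam builds on NVMeOptimizer, whose numel bookkeeping is now deferred
# to the first step (see the nvme_optimizer change above), so it can be
# constructed over not-yet-materialized parameters.
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()

plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

x = torch.rand(4, 32, device='cuda')
loss = criterion(model(x))
booster.backward(loss, optimizer)
optimizer.step()

The test's three init methods correspond to plain construction ('none'), ColoInitContext ('colo'), and the newly supported LazyInitContext ('lazy'); only the last one requires the meta-tensor support checked by is_compatible_with_meta.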