From 173dad05628489f50bd131892b0b3cfbc40b4eb4 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 17 Apr 2023 11:25:35 +0800 Subject: [PATCH] [misc] add verbose arg for zero and op builder (#3552) * [misc] add print verbose * [gemini] add print verbose * [zero] add print verbose for low level * [misc] add print verbose for op builder --- colossalai/booster/plugin/gemini_plugin.py | 25 +++++++++++++++----- colossalai/zero/gemini/chunk/utils.py | 3 ++- colossalai/zero/gemini/gemini_ddp.py | 6 +++-- colossalai/zero/gemini/gemini_optimizer.py | 6 ++++- colossalai/zero/low_level/low_level_optim.py | 2 ++ colossalai/zero/wrapper.py | 15 ++++++++---- op_builder/builder.py | 6 +++-- op_builder/utils.py | 20 +++++++--------- 8 files changed, 55 insertions(+), 28 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 659f36c21..deda00d8a 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -65,9 +65,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO): class GeminiModel(ModelWrapper): - def __init__(self, module: nn.Module, gemini_config: dict) -> None: + def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None: super().__init__(module) - self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config) + self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose) def unwrap(self): # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model @@ -76,8 +76,17 @@ class GeminiModel(ModelWrapper): class GeminiOptimizer(OptimizerWrapper): - def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None: - optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs) + def __init__(self, + module: GeminiDDP, + optimizer: Optimizer, + zero_optim_config: dict, + optim_kwargs: dict, + verbose: bool = False) -> None: + optimizer = zero_optim_wrapper(module, + optimizer, + optim_config=zero_optim_config, + **optim_kwargs, + verbose=verbose) super().__init__(optimizer) def backward(self, loss: Tensor, *args, **kwargs): @@ -138,6 +147,7 @@ class GeminiPlugin(Plugin): max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. norm_type (float, optional): norm_type used for `clip_grad_norm`. + verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. """ def __init__( @@ -161,6 +171,7 @@ class GeminiPlugin(Plugin): max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, + verbose: bool = False, ) -> None: assert dist.is_initialized( @@ -188,6 +199,7 @@ class GeminiPlugin(Plugin): max_scale=max_scale, max_norm=max_norm, norm_type=norm_type) + self.verbose = verbose def support_no_sync(self) -> bool: return False @@ -275,10 +287,11 @@ class GeminiPlugin(Plugin): # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) # wrap the model with Gemini - model = GeminiModel(model, self.gemini_config) + model = GeminiModel(model, self.gemini_config, self.verbose) if not isinstance(optimizer, OptimizerWrapper): - optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs) + optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, + self.verbose) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/zero/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py index 283f74203..71242dcd6 100644 --- a/colossalai/zero/gemini/chunk/utils.py +++ b/colossalai/zero/gemini/chunk/utils.py @@ -20,6 +20,7 @@ def safe_div(a, b): def init_chunk_manager(model: nn.Module, init_device: Optional[torch.device] = None, hidden_dim: Optional[int] = None, + verbose: bool = False, **kwargs) -> ChunkManager: if hidden_dim: search_interval_byte = hidden_dim @@ -39,7 +40,7 @@ def init_chunk_manager(model: nn.Module, total_size /= mb_size wasted_size /= mb_size - if dist.get_rank() == 0: + if verbose and dist.get_rank() == 0: print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s), "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size), "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)), diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c06239dfa..2e35be066 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -567,7 +567,8 @@ class GeminiDDP(ZeroDDP): search_range_mb: int = 32, hidden_dim: Optional[int] = None, min_chunk_size_mb: float = 32, - memstats: Optional[MemStats] = None) -> None: + memstats: Optional[MemStats] = None, + verbose: bool = False) -> None: """ A torch.Module warpper using ZeRO-DP and Genimi. ZeRO is for parallel. Gemini is for memory management. @@ -604,6 +605,7 @@ class GeminiDDP(ZeroDDP): hidden_dim=hidden_dim, search_range_mb=search_range_mb, min_chunk_size_mb=min_chunk_size_mb, - strict_ddp_flag=strict_ddp_mode) + strict_ddp_flag=strict_ddp_mode, + verbose=verbose) gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 8940ab9a3..71c4f65cb 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -54,6 +54,7 @@ class ZeroOptimizer(ColossalaiOptimizer): clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0. norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0) is supported in ZeroOptimizer. Defaults to 2.0. + verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False. """ def __init__(self, @@ -69,6 +70,7 @@ class ZeroOptimizer(ColossalaiOptimizer): max_scale: float = 2**32, clipping_norm: float = 0.0, norm_type: float = 2.0, + verbose: bool = False, **defaults: Any): super().__init__(optim) assert isinstance(module, ZeroDDP) @@ -83,6 +85,7 @@ class ZeroOptimizer(ColossalaiOptimizer): self.chunk16_set: Set[Chunk] = set() self.clipping_flag = clipping_norm > 0.0 self.max_norm = clipping_norm + self.verbose = verbose if self.clipping_flag: assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" @@ -221,7 +224,8 @@ class ZeroOptimizer(ColossalaiOptimizer): if found_inf: self.optim_state = OptimState.UNSCALED # no need to unscale grad self.grad_scaler.update(found_inf) # update gradient scaler - self._logger.info(f'Found overflow. Skip step') + if self.verbose: + self._logger.info(f'Found overflow. Skip step') self._clear_global_norm() # clear recorded norm self.zero_grad() # reset all gradients self._update_fp16_params() diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 49fb8b54b..39ade27b9 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -440,6 +440,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # update loss scale if overflow occurs if found_inf: self._grad_store.reset_all_average_gradients() + if self._verbose: + self._logger.info(f'Found overflow. Skip step') self.zero_grad() return diff --git a/colossalai/zero/wrapper.py b/colossalai/zero/wrapper.py index 4553249e2..6cdb8fc59 100644 --- a/colossalai/zero/wrapper.py +++ b/colossalai/zero/wrapper.py @@ -7,7 +7,10 @@ import torch.nn as nn from .gemini import GeminiDDP -def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None): +def zero_model_wrapper(model: nn.Module, + zero_stage: int = 1, + gemini_config: Optional[Dict] = None, + verbose: bool = False): """This wrapper function is used to wrap your training model for ZeRO DDP. Example: @@ -40,7 +43,7 @@ def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Opt if zero_stage in [1, 2]: wrapped_model = model else: - wrapped_model = GeminiDDP(model, **gemini_config) + wrapped_model = GeminiDDP(model, **gemini_config, verbose=verbose) setattr(wrapped_model, "_colo_zero_stage", zero_stage) @@ -58,7 +61,8 @@ def zero_optim_wrapper(model: nn.Module, max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, - optim_config: Optional[Dict] = None): + optim_config: Optional[Dict] = None, + verbose: bool = False): """This wrapper function is used to wrap your training optimizer for ZeRO DDP. Args: @@ -79,6 +83,7 @@ def zero_optim_wrapper(model: nn.Module, >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True) >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config) + verbose (bool, optional): Whether to print the verbose info. """ assert hasattr(model, "_colo_zero_stage"), "You should use `zero_ddp_wrapper` first" zero_stage = getattr(model, "_colo_zero_stage") @@ -102,8 +107,8 @@ def zero_optim_wrapper(model: nn.Module, from colossalai.zero.low_level import LowLevelZeroOptimizer config_dict['partition_grad'] = zero_stage == 2 config_dict['clip_grad_norm'] = max_norm - return LowLevelZeroOptimizer(optimizer, **config_dict) + return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose) else: from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer config_dict['clipping_norm'] = max_norm - return ZeroOptimizer(optimizer, model, **config_dict) + return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose) diff --git a/op_builder/builder.py b/op_builder/builder.py index b9f44decc..16bf173ff 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -7,7 +7,7 @@ import os import time from abc import ABC, abstractmethod from pathlib import Path -from typing import List +from typing import List, Optional from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 @@ -138,7 +138,7 @@ class Builder(ABC): # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not check_system_pytorch_cuda_match(CUDA_HOME) - def load(self, verbose=True): + def load(self, verbose: Optional[bool] = None): """ load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel. If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the @@ -149,6 +149,8 @@ class Builder(ABC): Args: verbose (bool, optional): show detailed info. Defaults to True. """ + if verbose is None: + verbose = os.environ.get('CAI_KERNEL_VERBOSE', '0') == '1' # if the kernel has be compiled and cached, we directly use it if self.cached_op_module is not None: return self.cached_op_module diff --git a/op_builder/utils.py b/op_builder/utils.py index 4029703e4..1b1bd5f49 100644 --- a/op_builder/utils.py +++ b/op_builder/utils.py @@ -90,7 +90,6 @@ def check_system_pytorch_cuda_match(cuda_dir): 'Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ .' ) - print(bare_metal_minor != torch_cuda_minor) if bare_metal_minor != torch_cuda_minor: warnings.warn( f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. " @@ -156,16 +155,15 @@ def set_cuda_arch_list(cuda_dir): # we only need to set this when CUDA is not available for cross-compilation if not cuda_available: - warnings.warn( - '\n[extension] PyTorch did not find available GPUs on this system.\n' - 'If your intention is to cross-compile, this is not an error.\n' - 'By default, Colossal-AI will cross-compile for \n' - '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n' - '2. Volta (compute capability 7.0)\n' - '3. Turing (compute capability 7.5),\n' - '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n' - '\nIf you wish to cross-compile for a single specific architecture,\n' - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') + warnings.warn('\n[extension] PyTorch did not find available GPUs on this system.\n' + 'If your intention is to cross-compile, this is not an error.\n' + 'By default, Colossal-AI will cross-compile for \n' + '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n' + '2. Volta (compute capability 7.0)\n' + '3. Turing (compute capability 7.5),\n' + '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n' + '\nIf you wish to cross-compile for a single specific architecture,\n' + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)