[misc] add verbose arg for zero and op builder (#3552)

* [misc] add print verbose

* [gemini] add print verbose

* [zero] add print verbose for low level

* [misc] add print verbose for op builder
Hongxin Liu 2023-04-17 11:25:35 +08:00 committed by GitHub
parent 4341f5e8e6
commit 173dad0562
8 changed files with 55 additions and 28 deletions

View File

@@ -65,9 +65,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
 class GeminiModel(ModelWrapper):
-    def __init__(self, module: nn.Module, gemini_config: dict) -> None:
+    def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
         super().__init__(module)
-        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)
+        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)
     def unwrap(self):
         # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
@@ -76,8 +76,17 @@ class GeminiModel(ModelWrapper):
 class GeminiOptimizer(OptimizerWrapper):
-    def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
-        optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
+    def __init__(self,
+                 module: GeminiDDP,
+                 optimizer: Optimizer,
+                 zero_optim_config: dict,
+                 optim_kwargs: dict,
+                 verbose: bool = False) -> None:
+        optimizer = zero_optim_wrapper(module,
+                                       optimizer,
+                                       optim_config=zero_optim_config,
+                                       **optim_kwargs,
+                                       verbose=verbose)
         super().__init__(optimizer)
     def backward(self, loss: Tensor, *args, **kwargs):
@@ -138,6 +147,7 @@ class GeminiPlugin(Plugin):
         max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
             clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
+        verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
     """
     def __init__(
@@ -161,6 +171,7 @@ class GeminiPlugin(Plugin):
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
+        verbose: bool = False,
     ) -> None:
         assert dist.is_initialized(
@@ -188,6 +199,7 @@ class GeminiPlugin(Plugin):
                                  max_scale=max_scale,
                                  max_norm=max_norm,
                                  norm_type=norm_type)
+        self.verbose = verbose
     def support_no_sync(self) -> bool:
         return False
@@ -275,10 +287,11 @@ class GeminiPlugin(Plugin):
         # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
         # wrap the model with Gemini
-        model = GeminiModel(model, self.gemini_config)
+        model = GeminiModel(model, self.gemini_config, self.verbose)
         if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs)
+            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+                                        self.verbose)
         return model, optimizer, criterion, dataloader, lr_scheduler
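
For context, a minimal usage sketch of the new flag at the plugin level (not part of the diff). It assumes the usual booster entry points of this ColossalAI release, i.e. colossalai.booster.Booster, colossalai.booster.plugin.GeminiPlugin, and a distributed launch via colossalai.launch_from_torch (run with torchrun / colossalai run):

    import colossalai
    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin

    colossalai.launch_from_torch(config={})

    model = torch.nn.Linear(1024, 1024)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # verbose=True is stored on the plugin and forwarded to GeminiModel / GeminiOptimizer,
    # so Gemini prints debug info such as the chunk search summary on rank 0.
    plugin = GeminiPlugin(verbose=True)
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)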

View File

@@ -20,6 +20,7 @@ def safe_div(a, b):
 def init_chunk_manager(model: nn.Module,
                        init_device: Optional[torch.device] = None,
                        hidden_dim: Optional[int] = None,
+                       verbose: bool = False,
                        **kwargs) -> ChunkManager:
     if hidden_dim:
         search_interval_byte = hidden_dim
@@ -39,7 +40,7 @@ def init_chunk_manager(model: nn.Module,
     total_size /= mb_size
     wasted_size /= mb_size
-    if dist.get_rank() == 0:
+    if verbose and dist.get_rank() == 0:
         print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
               "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
               "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),

View File

@@ -567,7 +567,8 @@ class GeminiDDP(ZeroDDP):
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: float = 32,
-                 memstats: Optional[MemStats] = None) -> None:
+                 memstats: Optional[MemStats] = None,
+                 verbose: bool = False) -> None:
         """
         A torch.Module warpper using ZeRO-DP and Genimi.
         ZeRO is for parallel. Gemini is for memory management.
@@ -604,6 +605,7 @@ class GeminiDDP(ZeroDDP):
                                            hidden_dim=hidden_dim,
                                            search_range_mb=search_range_mb,
                                            min_chunk_size_mb=min_chunk_size_mb,
-                                           strict_ddp_flag=strict_ddp_mode)
+                                           strict_ddp_flag=strict_ddp_mode,
+                                           verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
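
A short sketch of passing the flag when constructing GeminiDDP directly (not part of the diff); the import path and the device/placement_policy arguments are assumptions based on the GeminiDDP constructor of this release:

    import torch
    from colossalai.zero import GeminiDDP  # assumed export path

    model = torch.nn.Linear(1024, 1024)
    # verbose=True is forwarded to init_chunk_manager, so the chunk search
    # summary is printed on rank 0 while the wrapper is being built.
    gemini_model = GeminiDDP(model,
                             device=torch.device('cuda'),
                             placement_policy='auto',
                             verbose=True)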

View File

@@ -54,6 +54,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
         norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0)
             is supported in ZeroOptimizer. Defaults to 2.0.
+        verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
     """
     def __init__(self,
@@ -69,6 +70,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  max_scale: float = 2**32,
                  clipping_norm: float = 0.0,
                  norm_type: float = 2.0,
+                 verbose: bool = False,
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
@@ -83,6 +85,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         self.chunk16_set: Set[Chunk] = set()
         self.clipping_flag = clipping_norm > 0.0
         self.max_norm = clipping_norm
+        self.verbose = verbose
         if self.clipping_flag:
             assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
@@ -221,7 +224,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
         if found_inf:
             self.optim_state = OptimState.UNSCALED    # no need to unscale grad
             self.grad_scaler.update(found_inf)    # update gradient scaler
-            self._logger.info(f'Found overflow. Skip step')
+            if self.verbose:
+                self._logger.info(f'Found overflow. Skip step')
             self._clear_global_norm()    # clear recorded norm
             self.zero_grad()    # reset all gradients
             self._update_fp16_params()

View File

@@ -440,6 +440,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # update loss scale if overflow occurs
         if found_inf:
             self._grad_store.reset_all_average_gradients()
+            if self._verbose:
+                self._logger.info(f'Found overflow. Skip step')
             self.zero_grad()
             return

View File

@@ -7,7 +7,10 @@ import torch.nn as nn
 from .gemini import GeminiDDP
-def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
+def zero_model_wrapper(model: nn.Module,
+                       zero_stage: int = 1,
+                       gemini_config: Optional[Dict] = None,
+                       verbose: bool = False):
     """This wrapper function is used to wrap your training model for ZeRO DDP.
     Example:
@@ -40,7 +43,7 @@ def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Opt
     if zero_stage in [1, 2]:
         wrapped_model = model
     else:
-        wrapped_model = GeminiDDP(model, **gemini_config)
+        wrapped_model = GeminiDDP(model, **gemini_config, verbose=verbose)
     setattr(wrapped_model, "_colo_zero_stage", zero_stage)
@@ -58,7 +61,8 @@ def zero_optim_wrapper(model: nn.Module,
                        max_scale: float = 2**32,
                        max_norm: float = 0.0,
                        norm_type: float = 2.0,
-                       optim_config: Optional[Dict] = None):
+                       optim_config: Optional[Dict] = None,
+                       verbose: bool = False):
     """This wrapper function is used to wrap your training optimizer for ZeRO DDP.
     Args:
@@ -79,6 +83,7 @@ def zero_optim_wrapper(model: nn.Module,
            >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
            >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config)
+        verbose (bool, optional): Whether to print the verbose info.
     """
     assert hasattr(model, "_colo_zero_stage"), "You should use `zero_ddp_wrapper` first"
     zero_stage = getattr(model, "_colo_zero_stage")
@@ -102,8 +107,8 @@ def zero_optim_wrapper(model: nn.Module,
         from colossalai.zero.low_level import LowLevelZeroOptimizer
         config_dict['partition_grad'] = zero_stage == 2
         config_dict['clip_grad_norm'] = max_norm
-        return LowLevelZeroOptimizer(optimizer, **config_dict)
+        return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)
     else:
         from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer
         config_dict['clipping_norm'] = max_norm
-        return ZeroOptimizer(optimizer, model, **config_dict)
+        return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose)
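
A minimal end-to-end sketch of the wrapper API with the new flag (not part of the diff); it assumes zero_model_wrapper and zero_optim_wrapper are exported from colossalai.zero, and that the keys of gemini_config follow the GeminiDDP constructor:

    import torch
    from colossalai.zero import zero_model_wrapper, zero_optim_wrapper  # assumed export path

    model = torch.nn.Linear(1024, 1024)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Stage 3 routes through GeminiDDP, which now also receives verbose.
    gemini_config = dict(device=torch.device('cuda'), placement_policy='auto')
    model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config, verbose=True)

    # The same flag is forwarded to ZeroOptimizer (stage 3) or LowLevelZeroOptimizer (stage 1/2).
    optim = zero_optim_wrapper(model, optim, verbose=True)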

View File

@@ -7,7 +7,7 @@ import os
 import time
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
@@ -138,7 +138,7 @@ class Builder(ABC):
         # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not
         check_system_pytorch_cuda_match(CUDA_HOME)
-    def load(self, verbose=True):
+    def load(self, verbose: Optional[bool] = None):
         """
         load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel.
         If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the
@@ -149,6 +149,8 @@ class Builder(ABC):
         Args:
             verbose (bool, optional): show detailed info. Defaults to True.
         """
+        if verbose is None:
+            verbose = os.environ.get('CAI_KERNEL_VERBOSE', '0') == '1'
         # if the kernel has be compiled and cached, we directly use it
         if self.cached_op_module is not None:
             return self.cached_op_module
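
A hedged sketch of the new Builder.load behaviour (not part of the diff). CPUAdamBuilder is used only as one concrete Builder subclass and its import path is an assumption; loading still JIT-builds the kernel, so a working compiler/CUDA toolchain is required:

    import os
    from colossalai.kernel.op_builder import CPUAdamBuilder  # assumed concrete Builder subclass

    # Explicit argument: behaves as before.
    cpu_adam_ext = CPUAdamBuilder().load(verbose=True)

    # Or leave verbose=None and control the output globally via the environment variable.
    os.environ['CAI_KERNEL_VERBOSE'] = '1'
    cpu_adam_ext = CPUAdamBuilder().load()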

View File

@@ -90,7 +90,6 @@ def check_system_pytorch_cuda_match(cuda_dir):
            'Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ .'
        )
-    print(bare_metal_minor != torch_cuda_minor)
     if bare_metal_minor != torch_cuda_minor:
         warnings.warn(
             f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
@@ -156,16 +155,15 @@ def set_cuda_arch_list(cuda_dir):
     # we only need to set this when CUDA is not available for cross-compilation
     if not cuda_available:
-        warnings.warn(
-            '\n[extension] PyTorch did not find available GPUs on this system.\n'
-            'If your intention is to cross-compile, this is not an error.\n'
-            'By default, Colossal-AI will cross-compile for \n'
-            '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
-            '2. Volta (compute capability 7.0)\n'
-            '3. Turing (compute capability 7.5),\n'
-            '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n'
-            '\nIf you wish to cross-compile for a single specific architecture,\n'
-            'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
+        warnings.warn('\n[extension] PyTorch did not find available GPUs on this system.\n'
+                      'If your intention is to cross-compile, this is not an error.\n'
+                      'By default, Colossal-AI will cross-compile for \n'
+                      '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
+                      '2. Volta (compute capability 7.0)\n'
+                      '3. Turing (compute capability 7.5),\n'
+                      '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n'
+                      '\nIf you wish to cross-compile for a single specific architecture,\n'
+                      'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
     if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
         bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)