Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-06-25 15:01:43 +00:00
[misc] add verbose arg for zero and op builder (#3552)
* [misc] add print verbose
* [gemini] add print verbose
* [zero] add print verbose for low level
* [misc] add print verbose for op builder
This commit is contained in:
parent 4341f5e8e6
commit 173dad0562
In the Gemini plugin, both wrappers now accept a `verbose` flag and forward it to the underlying zero wrappers:

```diff
@@ -65,9 +65,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
 class GeminiModel(ModelWrapper):

-    def __init__(self, module: nn.Module, gemini_config: dict) -> None:
+    def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
         super().__init__(module)
-        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)
+        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)

     def unwrap(self):
         # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
@@ -76,8 +76,17 @@ class GeminiModel(ModelWrapper):
 class GeminiOptimizer(OptimizerWrapper):

-    def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
-        optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
+    def __init__(self,
+                 module: GeminiDDP,
+                 optimizer: Optimizer,
+                 zero_optim_config: dict,
+                 optim_kwargs: dict,
+                 verbose: bool = False) -> None:
+        optimizer = zero_optim_wrapper(module,
+                                       optimizer,
+                                       optim_config=zero_optim_config,
+                                       **optim_kwargs,
+                                       verbose=verbose)
         super().__init__(optimizer)

     def backward(self, loss: Tensor, *args, **kwargs):
```
GeminiPlugin documents the flag, accepts it in `__init__`, stores it, and passes it to the model and optimizer wrappers in `configure`:

```diff
@@ -138,6 +147,7 @@ class GeminiPlugin(Plugin):
         max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
             clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
+        verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
     """

     def __init__(
@@ -161,6 +171,7 @@ class GeminiPlugin(Plugin):
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
+        verbose: bool = False,
     ) -> None:

         assert dist.is_initialized(
@@ -188,6 +199,7 @@ class GeminiPlugin(Plugin):
                                  max_scale=max_scale,
                                  max_norm=max_norm,
                                  norm_type=norm_type)
+        self.verbose = verbose

     def support_no_sync(self) -> bool:
         return False
@@ -275,10 +287,11 @@ class GeminiPlugin(Plugin):
            # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

            # wrap the model with Gemini
-            model = GeminiModel(model, self.gemini_config)
+            model = GeminiModel(model, self.gemini_config, self.verbose)

        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs)
+            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
+                                        self.verbose)

        return model, optimizer, criterion, dataloader, lr_scheduler
```
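With the flag stored on the plugin, turning on the debug output is a one-argument change for the user. A minimal usage sketch (a hypothetical script: the launch call, model, and optimizer are placeholder boilerplate, and the Booster API is assumed to match this release):

```python
import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})

plugin = GeminiPlugin(verbose=True)    # chunk search results are printed on rank 0
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)
```

This would need to run under a distributed launcher such as torchrun, since the plugin asserts that the distributed backend is already initialized.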
In the chunk manager, the chunk-search summary becomes opt-in instead of always printing on rank 0:

```diff
@@ -20,6 +20,7 @@ def safe_div(a, b):
 def init_chunk_manager(model: nn.Module,
                        init_device: Optional[torch.device] = None,
                        hidden_dim: Optional[int] = None,
+                       verbose: bool = False,
                        **kwargs) -> ChunkManager:
     if hidden_dim:
         search_interval_byte = hidden_dim
@@ -39,7 +40,7 @@ def init_chunk_manager(model: nn.Module,
     total_size /= mb_size
     wasted_size /= mb_size

-    if dist.get_rank() == 0:
+    if verbose and dist.get_rank() == 0:
         print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
               "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
               "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
```
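Two details of the new gate are worth noting: the summary is opt-in and printed only on rank 0, so a multi-GPU job does not emit one copy per process, and the wasted percentage goes through `safe_div`, so a degenerate search result cannot divide by zero. A self-contained sketch of the same pattern (function and variable names are illustrative, not the library's):

```python
import torch.distributed as dist

def safe_div(a: float, b: float) -> float:
    # Same intent as the helper in the diff: return 0 instead of raising
    # ZeroDivisionError when the denominator is zero.
    return a / b if b != 0 else 0.0

def report_chunk_search(span_s: float, used_mb: float, wasted_mb: float, verbose: bool = False) -> None:
    # Opt-in and rank-0-only, so large jobs stay quiet by default.
    if verbose and (not dist.is_initialized() or dist.get_rank() == 0):
        pct = 100 * safe_div(wasted_mb, used_mb + wasted_mb)
        print(f"searching chunk configuration is completed in {span_s:.2f} s.\n"
              f"used number: {used_mb:.2f} MB, wasted number: {wasted_mb:.2f} MB\n"
              f"total wasted percentage is {pct:.2f}%")
```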
GeminiDDP accepts the flag and forwards it to `init_chunk_manager`:

```diff
@@ -567,7 +567,8 @@ class GeminiDDP(ZeroDDP):
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: float = 32,
-                 memstats: Optional[MemStats] = None) -> None:
+                 memstats: Optional[MemStats] = None,
+                 verbose: bool = False) -> None:
         """
         A torch.Module warpper using ZeRO-DP and Genimi.
         ZeRO is for parallel. Gemini is for memory management.
@@ -604,6 +605,7 @@ class GeminiDDP(ZeroDDP):
                                            hidden_dim=hidden_dim,
                                            search_range_mb=search_range_mb,
                                            min_chunk_size_mb=min_chunk_size_mb,
-                                           strict_ddp_flag=strict_ddp_mode)
+                                           strict_ddp_flag=strict_ddp_mode,
+                                           verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
```
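Users who construct `GeminiDDP` directly, rather than through the plugin or `zero_model_wrapper`, pass the new keyword at construction time. A hedged sketch (the import path and argument values are assumptions; all other parameters keep the defaults shown above):

```python
import torch
import torch.nn as nn
from colossalai.zero import GeminiDDP    # import path assumed

module = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).cuda()
model = GeminiDDP(module,
                  device=torch.device('cuda'),
                  placement_policy='cpu',
                  verbose=True)    # surfaces the chunk search summary on rank 0
```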
ZeroOptimizer documents and accepts the flag, stores it, and gates the overflow log message:

```diff
@@ -54,6 +54,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
         norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0)
             is supported in ZeroOptimizer. Defaults to 2.0.
+        verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
     """

     def __init__(self,
@@ -69,6 +70,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  max_scale: float = 2**32,
                  clipping_norm: float = 0.0,
                  norm_type: float = 2.0,
+                 verbose: bool = False,
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
@@ -83,6 +85,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         self.chunk16_set: Set[Chunk] = set()
         self.clipping_flag = clipping_norm > 0.0
         self.max_norm = clipping_norm
+        self.verbose = verbose

         if self.clipping_flag:
             assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
@@ -221,7 +224,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
         if found_inf:
             self.optim_state = OptimState.UNSCALED    # no need to unscale grad
             self.grad_scaler.update(found_inf)    # update gradient scaler
-            self._logger.info(f'Found overflow. Skip step')
+            if self.verbose:
+                self._logger.info(f'Found overflow. Skip step')
             self._clear_global_norm()    # clear recorded norm
             self.zero_grad()    # reset all gradients
             self._update_fp16_params()
```
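The log line itself is unchanged; it is only gated, so overflow skips stay silent unless asked for. For context, the surrounding logic is the standard dynamic-loss-scaling step: on overflow the update is abandoned and the scaler shrinks the loss scale. A generic skeleton of that pattern (an illustration with hypothetical helper methods, not ColossalAI's actual implementation):

```python
class MixedPrecisionOptimizerSketch:
    """Illustrative skeleton of a dynamic-loss-scaling step; not the library's code."""

    def step(self):
        found_inf = self._check_overflow()        # any inf/NaN in the fp16 grads?
        if found_inf:
            self.grad_scaler.update(found_inf)    # overflow: shrink the loss scale
            if self.verbose:
                self._logger.info('Found overflow. Skip step')
            self.zero_grad()                      # drop this step's gradients entirely
            return
        self._unscale_grads()                     # divide grads by the loss scale
        self.optim.step()                         # apply the update
        self.grad_scaler.update(found_inf)        # no overflow: possibly grow the scale
```

The LowLevelZeroOptimizer hunk below applies the same gating through a `_verbose` attribute, which the `zero_optim_wrapper` change further down supplies as a constructor argument (the constructor change itself is outside this excerpt).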
LowLevelZeroOptimizer gets the same treatment for its overflow message:

```diff
@@ -440,6 +440,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # update loss scale if overflow occurs
         if found_inf:
             self._grad_store.reset_all_average_gradients()
+            if self._verbose:
+                self._logger.info(f'Found overflow. Skip step')
             self.zero_grad()
             return

```
`zero_model_wrapper` reflows its signature and passes the flag to GeminiDDP; for stages 1 and 2 the model is returned unwrapped, so the flag only matters for stage 3:

```diff
@@ -7,7 +7,10 @@ import torch.nn as nn
 from .gemini import GeminiDDP


-def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
+def zero_model_wrapper(model: nn.Module,
+                       zero_stage: int = 1,
+                       gemini_config: Optional[Dict] = None,
+                       verbose: bool = False):
     """This wrapper function is used to wrap your training model for ZeRO DDP.

     Example:
@@ -40,7 +43,7 @@ def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
     if zero_stage in [1, 2]:
         wrapped_model = model
     else:
-        wrapped_model = GeminiDDP(model, **gemini_config)
+        wrapped_model = GeminiDDP(model, **gemini_config, verbose=verbose)

     setattr(wrapped_model, "_colo_zero_stage", zero_stage)

```
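A short usage sketch for the updated signature (the import path and `gemini_config` keys are assumptions; for stages 1 and 2 the flag is accepted but has no effect, since the model is returned unwrapped):

```python
import torch.nn as nn
from colossalai.zero import zero_model_wrapper    # import path assumed

model = nn.Linear(128, 128)
gemini_config = dict(device='cuda', placement_policy='cpu')    # illustrative values
model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config, verbose=True)
```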
`zero_optim_wrapper` accepts the flag, documents it, and passes it to whichever optimizer class the zero stage selects:

```diff
@@ -58,7 +61,8 @@ def zero_optim_wrapper(model: nn.Module,
                        max_scale: float = 2**32,
                        max_norm: float = 0.0,
                        norm_type: float = 2.0,
-                       optim_config: Optional[Dict] = None):
+                       optim_config: Optional[Dict] = None,
+                       verbose: bool = False):
     """This wrapper function is used to wrap your training optimizer for ZeRO DDP.

     Args:
@@ -79,6 +83,7 @@ def zero_optim_wrapper(model: nn.Module,

         >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
         >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config)
+        verbose (bool, optional): Whether to print the verbose info.
     """
     assert hasattr(model, "_colo_zero_stage"), "You should use `zero_ddp_wrapper` first"
     zero_stage = getattr(model, "_colo_zero_stage")
@@ -102,8 +107,8 @@ def zero_optim_wrapper(model: nn.Module,
         from colossalai.zero.low_level import LowLevelZeroOptimizer
         config_dict['partition_grad'] = zero_stage == 2
         config_dict['clip_grad_norm'] = max_norm
-        return LowLevelZeroOptimizer(optimizer, **config_dict)
+        return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)
     else:
         from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer
         config_dict['clipping_norm'] = max_norm
-        return ZeroOptimizer(optimizer, model, **config_dict)
+        return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose)
```
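Extending the docstring's own stage-2 example with the new keyword (import paths assumed; the config values are taken from the docstring):

```python
import torch
import torch.nn as nn
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper    # import paths assumed

model = zero_model_wrapper(nn.Linear(128, 128), zero_stage=2)
optim = torch.optim.Adam(model.parameters())
zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
optim = zero_optim_wrapper(model, optim, optim_config=zero2_config, verbose=True)
```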
In the op builder, `Builder.load` replaces its hard-coded `verbose=True` default with `None`, resolved at call time from the `CAI_KERNEL_VERBOSE` environment variable:

```diff
@@ -7,7 +7,7 @@ import os
 import time
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import List
+from typing import List, Optional

 from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0

@@ -138,7 +138,7 @@ class Builder(ABC):
         # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not
         check_system_pytorch_cuda_match(CUDA_HOME)

-    def load(self, verbose=True):
+    def load(self, verbose: Optional[bool] = None):
         """
         load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel.
         If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the
@@ -149,6 +149,8 @@ class Builder(ABC):
         Args:
             verbose (bool, optional): show detailed info. Defaults to True.
         """
+        if verbose is None:
+            verbose = os.environ.get('CAI_KERNEL_VERBOSE', '0') == '1'
         # if the kernel has be compiled and cached, we directly use it
         if self.cached_op_module is not None:
             return self.cached_op_module
```
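Note the behavioral change in `load`: kernel-build chatter is now off by default and opt-in, either per call or process-wide (the docstring's "Defaults to True" line is a leftover from the old default). Two equivalent ways to re-enable it, as a sketch (`CPUAdamBuilder` stands in for any concrete builder; import path assumed):

```python
import os
from op_builder import CPUAdamBuilder    # one concrete Builder subclass (import path assumed)

# Option 1: per call, from Python.
kernel = CPUAdamBuilder().load(verbose=True)

# Option 2: process-wide, through the environment variable that load() consults.
os.environ['CAI_KERNEL_VERBOSE'] = '1'    # must be set before load() runs
kernel = CPUAdamBuilder().load()          # verbose resolves to True
```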
The op builder utilities drop a stray debug print in `check_system_pytorch_cuda_match` and reflow the cross-compile warning in `set_cuda_arch_list`:

```diff
@@ -90,7 +90,6 @@ def check_system_pytorch_cuda_match(cuda_dir):
             'Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ .'
         )

-    print(bare_metal_minor != torch_cuda_minor)
     if bare_metal_minor != torch_cuda_minor:
         warnings.warn(
             f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
@@ -156,16 +155,15 @@ def set_cuda_arch_list(cuda_dir):

     # we only need to set this when CUDA is not available for cross-compilation
     if not cuda_available:
-        warnings.warn(
-            '\n[extension] PyTorch did not find available GPUs on this system.\n'
-            'If your intention is to cross-compile, this is not an error.\n'
-            'By default, Colossal-AI will cross-compile for \n'
-            '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
-            '2. Volta (compute capability 7.0)\n'
-            '3. Turing (compute capability 7.5),\n'
-            '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n'
-            '\nIf you wish to cross-compile for a single specific architecture,\n'
-            'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
+        warnings.warn('\n[extension] PyTorch did not find available GPUs on this system.\n'
+                      'If your intention is to cross-compile, this is not an error.\n'
+                      'By default, Colossal-AI will cross-compile for \n'
+                      '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
+                      '2. Volta (compute capability 7.0)\n'
+                      '3. Turing (compute capability 7.5),\n'
+                      '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n'
+                      '\nIf you wish to cross-compile for a single specific architecture,\n'
+                      'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')

     if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
         bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
```
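The reflowed warning also documents the cross-compilation escape hatch. Since `set_cuda_arch_list` only fills in `TORCH_CUDA_ARCH_LIST` when it is unset, targeting a single architecture is a matter of setting the variable before any build is triggered, e.g. `export TORCH_CUDA_ARCH_LIST="8.0"` in the shell, or from Python (hypothetical example):

```python
import os

# Hypothetical example: cross-compile only for Ampere (compute capability 8.0)
# instead of the default architecture list. This must run before the op builder
# or setup.py triggers compilation, because set_cuda_arch_list only supplies
# defaults when the variable is unset.
os.environ['TORCH_CUDA_ARCH_LIST'] = '8.0'
```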