Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 06:30:41 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
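The reformatted hunks below are what "run pre-commit" produced: single quotes become double quotes and multi-line signatures gain trailing commas, consistent with a black-style formatter. As a hedged illustration (the exact hook set lives in the repository's .pre-commit-config.yaml and is not part of this diff), the "run all files" step amounts to invoking the pre-commit CLI over the whole tree:

    # Sketch only: re-run every configured pre-commit hook on all tracked files.
    # Assumes `pip install pre-commit` and a .pre-commit-config.yaml at the repo root.
    import subprocess

    result = subprocess.run(["pre-commit", "run", "--all-files"], check=False)
    print(f"pre-commit exited with {result.returncode}")  # non-zero if any hook modified or rejected files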
@@ -6,16 +6,22 @@ from .fp16_torch import FP16TorchMixedPrecision
 from .mixed_precision_base import MixedPrecision

 __all__ = [
-    'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision',
-    'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision', 'FP16NaiveMixedPrecision'
+    "MixedPrecision",
+    "mixed_precision_factory",
+    "FP16_Apex_MixedPrecision",
+    "FP16_Torch_MixedPrecision",
+    "FP32_MixedPrecision",
+    "BF16_MixedPrecision",
+    "FP8_MixedPrecision",
+    "FP16NaiveMixedPrecision",
 ]

 _mixed_precision_mapping = {
-    'fp16': FP16TorchMixedPrecision,
-    'fp16_apex': FP16ApexMixedPrecision,
-    'fp16_naive': FP16NaiveMixedPrecision,
-    'bf16': BF16MixedPrecision,
-    'fp8': FP8MixedPrecision
+    "fp16": FP16TorchMixedPrecision,
+    "fp16_apex": FP16ApexMixedPrecision,
+    "fp16_naive": FP16NaiveMixedPrecision,
+    "bf16": BF16MixedPrecision,
+    "fp8": FP8MixedPrecision,
 }

@@ -31,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
         return _mixed_precision_mapping[mixed_precision_type]()
     else:
         raise ValueError(
-            f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
+            f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}"
         )
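For context, a usage sketch of the factory above. The import path (colossalai.booster.mixed_precision) is assumed from the surrounding package and does not appear in this hunk:

    # Hypothetical usage of mixed_precision_factory (import path assumed).
    from colossalai.booster.mixed_precision import mixed_precision_factory

    amp = mixed_precision_factory("fp16")   # an FP16TorchMixedPrecision instance, per the mapping above
    try:
        mixed_precision_factory("int8")     # not a key of _mixed_precision_mapping
    except ValueError as err:
        print(err)                          # message lists the supported type names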
@@ -23,16 +23,18 @@ class FP16ApexMixedPrecision(MixedPrecision):
         max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.
     """

-    def __init__(self,
-                 opt_level: Optional[str] = "O1",
-                 cast_model_type: torch.dtype = None,
-                 patch_torch_functions: bool = None,
-                 keep_batchnorm_fp32: Union[bool, str] = None,
-                 master_weights: bool = None,
-                 loss_scale: Union[float, str] = None,
-                 cast_model_outputs: Any = None,
-                 num_losses: Optional[int] = 1,
-                 verbosity: int = 1,
-                 min_loss_scale: float = None,
-                 max_loss_scale: float = 2.**24) -> None:
+    def __init__(
+        self,
+        opt_level: Optional[str] = "O1",
+        cast_model_type: torch.dtype = None,
+        patch_torch_functions: bool = None,
+        keep_batchnorm_fp32: Union[bool, str] = None,
+        master_weights: bool = None,
+        loss_scale: Union[float, str] = None,
+        cast_model_outputs: Any = None,
+        num_losses: Optional[int] = 1,
+        verbosity: int = 1,
+        min_loss_scale: float = None,
+        max_loss_scale: float = 2.0**24,
+    ) -> None:
         pass
@@ -15,12 +15,14 @@ class FP16NaiveMixedPrecision(MixedPrecision):
         verbose(bool): if set to `True`, will print debug info.
     """

-    def __init__(self,
-                 log_num_zeros_in_grad: bool,
-                 initial_scale: int,
-                 growth_factor: int,
-                 backoff_factor: float,
-                 hysteresis: int,
-                 max_scale: int,
-                 verbose: bool = None) -> None:
+    def __init__(
+        self,
+        log_num_zeros_in_grad: bool,
+        initial_scale: int,
+        growth_factor: int,
+        backoff_factor: float,
+        hysteresis: int,
+        max_scale: int,
+        verbose: bool = None,
+    ) -> None:
         pass
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -9,7 +9,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper

 from .mixed_precision_base import MixedPrecision

-__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
+__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"]


 class TorchAMPOptimizer(OptimizerWrapper):
@@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper):
            calls that may cause the scale to increase. Default: 2000.
     """

-    def __init__(self,
-                 optim: Optimizer,
-                 init_scale: float = 2.**16,
-                 growth_factor: float = 2.0,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 2000) -> None:
+    def __init__(
+        self,
+        optim: Optimizer,
+        init_scale: float = 2.0**16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+    ) -> None:
         super().__init__(optim)
-        self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
-                                                growth_factor=growth_factor,
-                                                backoff_factor=backoff_factor,
-                                                growth_interval=growth_interval)
+        self.scaler = torch.cuda.amp.GradScaler(
+            init_scale=init_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+        )

     def backward(self, loss: Tensor, *args, **kwargs) -> None:
         scaled_loss = self.scale_loss(loss)
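The GradScaler constructed above is the standard torch.cuda.amp scaler. For reference, a minimal plain-PyTorch sketch of the scale/step/update cycle it drives (a CUDA device is assumed; this is not the wrapper's own code):

    # Plain PyTorch dynamic loss scaling with torch.cuda.amp.GradScaler.
    import torch

    model = torch.nn.Linear(8, 8).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(
        init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000
    )

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(torch.randn(4, 8, device="cuda")).sum()
    scaler.scale(loss).backward()   # backward pass on the scaled loss
    scaler.step(optimizer)          # unscales grads, skips the step on inf/nan
    scaler.update()                 # grows or backs off the scale for the next iteration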
@@ -60,12 +64,14 @@ class TorchAMPOptimizer(OptimizerWrapper):
         self.unscale_grad()
         super().clip_grad_by_value(clip_value, *args, **kwargs)

-    def clip_grad_by_norm(self,
-                          max_norm: Union[float, int],
-                          norm_type: Union[float, int] = 2.0,
-                          error_if_nonfinite: bool = False,
-                          *args,
-                          **kwargs) -> None:
+    def clip_grad_by_norm(
+        self,
+        max_norm: Union[float, int],
+        norm_type: Union[float, int] = 2.0,
+        error_if_nonfinite: bool = False,
+        *args,
+        **kwargs,
+    ) -> None:
         self.unscale_grad()
         super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)

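clip_grad_by_norm above unscales gradients before delegating to the parent wrapper. A short sketch of why that ordering matters, continuing the plain-PyTorch example above (equivalent behaviour is assumed; this is not the wrapper's internal code):

    # Unscale before clipping, otherwise the norm is computed on loss-scale-multiplied
    # gradients and the max_norm threshold loses its meaning.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(torch.randn(4, 8, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)      # divide gradients by the current loss scale
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0)
    scaler.step(optimizer)          # already-unscaled grads are not unscaled twice
    scaler.update()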
@@ -102,22 +108,27 @@ class FP16TorchMixedPrecision(MixedPrecision):
            calls that may cause the scale to increase. Default: 2000.
     """

-    def __init__(self,
-                 init_scale: float = 2.**16,
-                 growth_factor: float = 2.0,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 2000) -> None:
+    def __init__(
+        self,
+        init_scale: float = 2.0**16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+    ) -> None:
         super().__init__()
-        self.torch_amp_kwargs = dict(init_scale=init_scale,
-                                     growth_factor=growth_factor,
-                                     backoff_factor=backoff_factor,
-                                     growth_interval=growth_interval)
+        self.torch_amp_kwargs = dict(
+            init_scale=init_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+        )

-    def configure(self,
-                  model: nn.Module,
-                  optimizer: Optional[Optimizer] = None,
-                  criterion: Optional[Callable] = None,
-                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+    def configure(
+        self,
+        model: nn.Module,
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
         if optimizer is not None:
             optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
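Finally, a hedged end-to-end sketch of the configure path touched by the last hunk. The import path is assumed, and only part of the method body is visible in this diff:

    # Hypothetical usage of FP16TorchMixedPrecision.configure (import path assumed).
    import torch
    from colossalai.booster.mixed_precision import FP16TorchMixedPrecision

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()

    mixed_precision = FP16TorchMixedPrecision(init_scale=2.0**16)
    model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
    # model is now a TorchAMPModule and optimizer a TorchAMPOptimizer carrying the GradScaler kwargs.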