Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
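The diff below is a mechanical reformat produced by the updated pre-commit hooks (a Black-style Python formatter, going by the changes): single quotes become double quotes, `2.**16` becomes the explicit `2.0**16`, and long signatures are exploded to one argument per line with a trailing comma. As a hedged illustration of that style, not code from this repository, the same kind of rewrite looks like this (the formatter explodes a signature only when it would not fit on the configured line length or already carries a trailing comma):

# Illustrative only; make_kwargs_before/make_kwargs_after are hypothetical helpers.
# Before the formatter: single quotes, bare float literal, aligned continuation lines.
def make_kwargs_before(init_scale=2.**16,
                       growth_factor=2.0):
    return {'init_scale': init_scale, 'growth_factor': growth_factor}

# After the formatter: double quotes, 2.0**16, one argument per line,
# and a trailing comma that keeps the exploded layout on future runs.
def make_kwargs_after(
    init_scale=2.0**16,
    growth_factor=2.0,
):
    return {"init_scale": init_scale, "growth_factor": growth_factor}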
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -9,7 +9,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .mixed_precision_base import MixedPrecision
 
-__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
+__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"]
 
 
 class TorchAMPOptimizer(OptimizerWrapper):
@@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper):
            calls that may cause the scale to increase. Default: 2000.
     """
 
-    def __init__(self,
-                 optim: Optimizer,
-                 init_scale: float = 2.**16,
-                 growth_factor: float = 2.0,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 2000) -> None:
+    def __init__(
+        self,
+        optim: Optimizer,
+        init_scale: float = 2.0**16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+    ) -> None:
         super().__init__(optim)
-        self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
-                                                growth_factor=growth_factor,
-                                                backoff_factor=backoff_factor,
-                                                growth_interval=growth_interval)
+        self.scaler = torch.cuda.amp.GradScaler(
+            init_scale=init_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+        )
 
     def backward(self, loss: Tensor, *args, **kwargs) -> None:
         scaled_loss = self.scale_loss(loss)
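For reference, the constructor reformatted above wraps `torch.cuda.amp.GradScaler`, whose standard loss-scaling pattern is what `TorchAMPOptimizer` packages behind the `OptimizerWrapper` interface. A minimal sketch of that raw-PyTorch pattern, with a toy model assumed for illustration (requires a CUDA device; this is not code from this file):

import torch
import torch.nn as nn

# Hypothetical toy setup used only to make the sketch runnable.
model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
scaler = torch.cuda.amp.GradScaler(init_scale=2.0**16)

inputs = torch.randn(8, 16, device="cuda")
targets = torch.randn(8, 4, device="cuda")

optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = criterion(model(inputs), targets)
scaler.scale(loss).backward()  # scale the loss, then backprop scaled gradients
scaler.step(optimizer)         # unscales gradients and skips the step on inf/nan
scaler.update()                # grows or shrinks the scale for the next iteration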
@@ -60,12 +64,14 @@ class TorchAMPOptimizer(OptimizerWrapper):
         self.unscale_grad()
         super().clip_grad_by_value(clip_value, *args, **kwargs)
 
-    def clip_grad_by_norm(self,
-                          max_norm: Union[float, int],
-                          norm_type: Union[float, int] = 2.0,
-                          error_if_nonfinite: bool = False,
-                          *args,
-                          **kwargs) -> None:
+    def clip_grad_by_norm(
+        self,
+        max_norm: Union[float, int],
+        norm_type: Union[float, int] = 2.0,
+        error_if_nonfinite: bool = False,
+        *args,
+        **kwargs,
+    ) -> None:
         self.unscale_grad()
         super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
 
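The `unscale_grad()` call before the clipping delegate follows the standard AMP rule: gradients produced from a scaled loss must be divided by the current scale before any norm- or value-based threshold is applied, otherwise the threshold would be compared against scaled magnitudes. A hedged sketch of the equivalent raw-PyTorch sequence, reusing the toy objects from the previous sketch:

optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = criterion(model(inputs), targets)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # divide the gradients by the current scale in place
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip true magnitudes
scaler.step(optimizer)      # detects that grads were already unscaled and does not redo it
scaler.update()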
@@ -102,22 +108,27 @@ class FP16TorchMixedPrecision(MixedPrecision):
            calls that may cause the scale to increase. Default: 2000.
     """
 
-    def __init__(self,
-                 init_scale: float = 2.**16,
-                 growth_factor: float = 2.0,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 2000) -> None:
+    def __init__(
+        self,
+        init_scale: float = 2.0**16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+    ) -> None:
         super().__init__()
-        self.torch_amp_kwargs = dict(init_scale=init_scale,
-                                     growth_factor=growth_factor,
-                                     backoff_factor=backoff_factor,
-                                     growth_interval=growth_interval)
+        self.torch_amp_kwargs = dict(
+            init_scale=init_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+        )
 
-    def configure(self,
-                  model: nn.Module,
-                  optimizer: Optional[Optimizer] = None,
-                  criterion: Optional[Callable] = None,
-                  ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+    def configure(
+        self,
+        model: nn.Module,
+        optimizer: Optional[Optimizer] = None,
+        criterion: Optional[Callable] = None,
+    ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
         model = TorchAMPModule(model)
         if optimizer is not None:
             optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
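Based only on the `configure` signature shown in this hunk, a hedged usage sketch follows; the import path is an assumption, and in practice the class is normally driven through ColossalAI's Booster rather than called directly:

import torch.nn as nn
from torch.optim import SGD

# Assumed import path, for illustration only.
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision

model = nn.Linear(16, 4)
optimizer = SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

mixed_precision = FP16TorchMixedPrecision(init_scale=2.0**16, growth_interval=2000)
model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
# model is now wrapped in TorchAMPModule; optimizer in TorchAMPOptimizer, which
# carries the GradScaler built from torch_amp_kwargs.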