Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 17:46:42 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -2,4 +2,4 @@ from .base_grad_scaler import BaseGradScaler
 from .constant_grad_scaler import ConstantGradScaler
 from .dynamic_grad_scaler import DynamicGradScaler

-__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler']
+__all__ = ["BaseGradScaler", "ConstantGradScaler", "DynamicGradScaler"]
@@ -9,7 +9,7 @@ from torch import Tensor

 from colossalai.logging import get_dist_logger

-__all__ = ['BaseGradScaler']
+__all__ = ["BaseGradScaler"]


 class BaseGradScaler(ABC):
@@ -30,24 +30,21 @@ class BaseGradScaler(ABC):

     @property
     def scale(self) -> Tensor:
-        """Returns the loss scale.
-        """
+        """Returns the loss scale."""

         return self._scale

     @property
     def inv_scale(self) -> Tensor:
-        """Returns the inverse of the loss scale.
-        """
+        """Returns the inverse of the loss scale."""

         return self._scale.double().reciprocal().float()

     def state_dict(self) -> Dict:
-        """Returns the states of the gradient scaler as a dict object.
-        """
+        """Returns the states of the gradient scaler as a dict object."""

         state_dict = dict()
-        state_dict['scale'] = self.scale
+        state_dict["scale"] = self.scale
         return state_dict

     def load_state_dict(self, state_dict: Dict) -> None:
@@ -57,7 +54,7 @@ class BaseGradScaler(ABC):
             state_dict (dict): the states of the gradient scaler
         """

-        self._scale = state_dict['scale']
+        self._scale = state_dict["scale"]

     @abstractmethod
     def update(self, overflow: bool) -> None:
@@ -67,8 +64,6 @@ class BaseGradScaler(ABC):
             overflow (bool): whether overflow occurs
         """

-        pass
-
     def log(self, message, *args, **kwargs):
         """Log messages.

@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-
 from .base_grad_scaler import BaseGradScaler

-__all__ = ['ConstantGradScaler']
+__all__ = ["ConstantGradScaler"]


 class ConstantGradScaler(BaseGradScaler):
@@ -23,4 +23,3 @@ class ConstantGradScaler(BaseGradScaler):
         Args:
             overflow (bool): whether overflow occurs
         """
-        pass
@@ -7,7 +7,7 @@ import torch

 from .base_grad_scaler import BaseGradScaler

-__all__ = ['DynamicGradScaler']
+__all__ = ["DynamicGradScaler"]


 class DynamicGradScaler(BaseGradScaler):
@@ -24,15 +24,17 @@ class DynamicGradScaler(BaseGradScaler):
         verbose (bool): whether to log messages, defaults to False
     """

-    def __init__(self,
-                 initial_scale: float = 2**16,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 min_scale: Optional[float] = None,
-                 max_scale: Optional[float] = None,
-                 hysteresis: int = 2,
-                 verbose: bool = False):
+    def __init__(
+        self,
+        initial_scale: float = 2**16,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        min_scale: Optional[float] = None,
+        max_scale: Optional[float] = None,
+        hysteresis: int = 2,
+        verbose: bool = False,
+    ):
         super().__init__(initial_scale, verbose)
         if min_scale:
             self._min_scale = torch.cuda.FloatTensor([min_scale])
@@ -53,18 +55,17 @@ class DynamicGradScaler(BaseGradScaler):
         self._sanity_checks()

     def _sanity_checks(self) -> None:
-        """Check if the arguments are correct.
-        """
+        """Check if the arguments are correct."""

         if self._min_scale:
-            assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
-            assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale'
+            assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
+            assert self._min_scale <= self._scale, "The minimum gradient scale cannot be greater than the current scale"
         if self._max_scale:
-            assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
-            assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale'
-        assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
-        assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
-        assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
+            assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative"
+            assert self._max_scale >= self._scale, "The maximum gradient scale cannot be smaller than the current scale"
+        assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
+        assert 0 < self._backoff_factor < 1, "The backoff factor must be between 0 and 1"
+        assert self._hysteresis >= 0, "The hysteresis cannot be negative"

     def update(self, overflow: bool) -> None:
         """Update the loss scale.
@@ -88,19 +89,18 @@ class DynamicGradScaler(BaseGradScaler):
                 self.log(
                     f"No overflow for consecutive {self._growth_interval} steps, "
                     f"the loss scale is adjusted to {self.scale.item()}",
-                    ranks=[0])
+                    ranks=[0],
+                )

     def _backoff_scale(self) -> None:
-        """Decrease the loss scale
-        """
+        """Decrease the loss scale"""

         self._scale = self._scale * self._backoff_factor
         if self._min_scale:
             self._scale = torch.max(self._scale, self._min_scale)

     def _grow_scale(self) -> None:
-        """Increase the loss scale
-        """
+        """Increase the loss scale"""

         self._scale = self._scale * self._growth_factor
         if self._max_scale:
@@ -108,14 +108,14 @@ class DynamicGradScaler(BaseGradScaler):

     def state_dict(self):
         state_dict = dict()
-        state_dict['scale'] = self._scale
-        state_dict['growth_factor'] = self._growth_factor
-        state_dict['backoff_factor'] = self._backoff_factor
-        state_dict['hysteresis'] = self._hysteresis
+        state_dict["scale"] = self._scale
+        state_dict["growth_factor"] = self._growth_factor
+        state_dict["backoff_factor"] = self._backoff_factor
+        state_dict["hysteresis"] = self._hysteresis
         return state_dict

     def load_state_dict(self, state_dict):
-        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
-        self._growth_factor = state_dict['growth_factor']
-        self._backoff_factor = state_dict['backoff_factor']
-        self._hysteresis = state_dict['hysteresis']
+        self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
+        self._growth_factor = state_dict["growth_factor"]
+        self._backoff_factor = state_dict["backoff_factor"]
+        self._hysteresis = state_dict["hysteresis"]
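For context, the files above implement ColossalAI's loss-scaling utilities: `BaseGradScaler` defines the interface, `ConstantGradScaler` keeps a fixed scale, and `DynamicGradScaler` grows the scale after `growth_interval` clean steps and backs it off when overflow persists. The sketch below shows how such a scaler is typically driven from a training loop. It is an illustration only: the import path, `model`, `optimizer`, `criterion`, and `batch` are assumptions, not part of this commit, and a CUDA device is required because the scaler stores its scale in a `torch.cuda.FloatTensor`.

```python
import torch

# Assumed import path; the diff only shows package-relative imports.
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(
    initial_scale=2**16,
    growth_factor=2,
    backoff_factor=0.5,
    growth_interval=1000,
    hysteresis=2,
)


def training_step(model, optimizer, criterion, batch):
    """Hypothetical fp16 training step driven by DynamicGradScaler."""
    optimizer.zero_grad()
    loss = criterion(model(batch["input"]), batch["target"])

    # Scale the loss so small fp16 gradients do not underflow during backward.
    (loss * scaler.scale).backward()

    # Simplified local overflow check on the scaled gradients.
    overflow = any(
        p.grad is not None and not torch.isfinite(p.grad).all() for p in model.parameters()
    )

    # Backs the scale off when overflow persists; grows it after `growth_interval` clean steps.
    scaler.update(overflow)
    if overflow:
        return  # skip this step; the reduced scale is used next time

    # Unscale gradients before the optimizer applies them.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scaler.inv_scale)
    optimizer.step()
```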
@@ -3,7 +3,7 @@ from .bf16 import BF16MixedPrecisionMixin
 from .fp16 import FP16MixedPrecisionMixin

 __all__ = [
-    'MixedPrecisionMixin',
-    'FP16MixedPrecisionMixin',
-    'BF16MixedPrecisionMixin',
+    "MixedPrecisionMixin",
+    "FP16MixedPrecisionMixin",
+    "BF16MixedPrecisionMixin",
 ]
@@ -39,6 +39,7 @@ class MixedPrecisionMixin(ABC):
             return self.optim.zero_grad()
     ```
     """
+
     dtype: torch.dtype

     @abstractmethod
@@ -51,7 +52,6 @@ class MixedPrecisionMixin(ABC):
         Returns:
             Tensor: Loss value (possibly scaled).
         """
-        pass

     @abstractmethod
     def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
@@ -64,7 +64,6 @@ class MixedPrecisionMixin(ABC):
         Returns:
             Tensor: Gradient of the tensor (possibly scaled).
         """
-        pass

     @abstractmethod
     def should_skip_step(self) -> bool:
@@ -73,13 +72,10 @@ class MixedPrecisionMixin(ABC):
         Returns:
             bool: Whether to skip the step.
         """
-        pass

     @abstractmethod
     def pre_zero_grad(self) -> None:
-        """Called before zero_grad.
-        """
-        pass
+        """Called before zero_grad."""

     @abstractmethod
     def get_grad_div_scale(self) -> float:
@@ -88,4 +84,3 @@ class MixedPrecisionMixin(ABC):
         Returns:
             float: A divisor for gradient clipping or step.
         """
-        pass
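The `MixedPrecisionMixin` hunks above reduce the abstract class to five hooks around an optimizer's `backward`/`step`/`zero_grad`. To make the contract concrete, here is a minimal, hypothetical implementation that runs in bf16 and applies no loss scaling at all; it is only an illustration of the interface (the import path is assumed), not the library's own `BF16MixedPrecisionMixin`.

```python
import torch
from torch import Tensor

# Assumed import path for the abstract class shown in the diff above.
from colossalai.amp.naive_amp.mixed_precision_mixin import MixedPrecisionMixin


class NoScaleMixedPrecisionMixin(MixedPrecisionMixin):
    """Hypothetical mixin: bf16 training with no loss scaling."""

    dtype = torch.bfloat16

    def pre_backward(self, loss: Tensor) -> Tensor:
        # No scaling: hand the loss back unchanged.
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        # No scaling: the incoming gradient is already correct.
        return grad

    def should_skip_step(self) -> bool:
        # Without a loss scale there is no overflow bookkeeping, so never skip.
        return False

    def pre_zero_grad(self) -> None:
        # Nothing to reset before zero_grad.
        pass

    def get_grad_div_scale(self) -> float:
        # Gradients were never scaled, so the divisor is 1.
        return 1.0
```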
@@ -19,22 +19,26 @@ class OptimState(Enum):
 class FP16MixedPrecisionMixin(MixedPrecisionMixin):
     dtype = torch.float16

-    def __init__(self,
-                 initial_scale: float = 2**16,
-                 min_scale: float = 1,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 max_scale: float = 2**32) -> None:
+    def __init__(
+        self,
+        initial_scale: float = 2**16,
+        min_scale: float = 1,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        max_scale: float = 2**32,
+    ) -> None:
         super().__init__()
-        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
-                                             min_scale=min_scale,
-                                             growth_factor=growth_factor,
-                                             backoff_factor=backoff_factor,
-                                             growth_interval=growth_interval,
-                                             hysteresis=hysteresis,
-                                             max_scale=max_scale)
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+        )
         self.optim_state = OptimState.UNSCALED
         self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())

@@ -49,7 +53,6 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
         Returns:
             bool: Whether there is overflow in the local process.
         """
-        pass

     def check_overflow(self) -> bool:
         # clear previous overflow record
@@ -79,6 +82,6 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
         pass

     def get_grad_div_scale(self) -> float:
-        assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
+        assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping"
         self.optim_state = OptimState.UNSCALED
         return self.loss_scale
@@ -11,18 +11,20 @@ from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin


 class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
-
-    def __init__(self,
-                 working_params: List[Parameter],
-                 initial_scale: float = 2**16,
-                 min_scale: float = 1,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 max_scale: float = 2**32) -> None:
-        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
-                         max_scale)
+    def __init__(
+        self,
+        working_params: List[Parameter],
+        initial_scale: float = 2**16,
+        min_scale: float = 1,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        max_scale: float = 2**32,
+    ) -> None:
+        super().__init__(
+            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+        )
         self.params = working_params

     def check_local_overflow(self) -> bool:
@@ -33,38 +35,41 @@ class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):


 class MixedPrecisionOptimizer(OptimizerWrapper):
-
-    def __init__(self,
-                 optim: Optimizer,
-                 precision: str = 'fp16',
-                 initial_scale: float = 2**16,
-                 min_scale: float = 1,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 max_scale: float = 2**32,
-                 max_norm: float = 0.0):
+    def __init__(
+        self,
+        optim: Optimizer,
+        precision: str = "fp16",
+        initial_scale: float = 2**16,
+        min_scale: float = 1,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        max_scale: float = 2**32,
+        max_norm: float = 0.0,
+    ):
         super().__init__(optim)
-        if precision == 'fp16':
+        if precision == "fp16":
             working_params = []
             for group in self.optim.param_groups:
-                for p in group['params']:
+                for p in group["params"]:
                     working_params.append(p)
-            self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
-                                                                initial_scale=initial_scale,
-                                                                min_scale=min_scale,
-                                                                growth_factor=growth_factor,
-                                                                backoff_factor=backoff_factor,
-                                                                growth_interval=growth_interval,
-                                                                hysteresis=hysteresis,
-                                                                max_scale=max_scale)
-        elif precision == 'bf16':
+            self.mixed_precision = NaiveFP16MixedPrecisionMixin(
+                working_params,
+                initial_scale=initial_scale,
+                min_scale=min_scale,
+                growth_factor=growth_factor,
+                backoff_factor=backoff_factor,
+                growth_interval=growth_interval,
+                hysteresis=hysteresis,
+                max_scale=max_scale,
+            )
+        elif precision == "bf16":
             self.mixed_precision = BF16MixedPrecisionMixin()
         else:
-            raise ValueError(f'Unsupported precision: {precision}')
+            raise ValueError(f"Unsupported precision: {precision}")
         if max_norm > 0.0:
-            raise NotImplementedError('max_norm is not supported yet.')
+            raise NotImplementedError("max_norm is not supported yet.")
         self.max_norm = max_norm
         self.working_to_master_map: Dict[Parameter, Tensor] = {}
         self.master_to_working_map: Dict[Tensor, Parameter] = {}
@@ -72,7 +77,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
         # create master weights
         for group in self.optim.param_groups:
             master_params = []
-            for p in group['params']:
+            for p in group["params"]:
                 if p.requires_grad:
                     master_p = p
                     if p.dtype != torch.float:
@@ -80,7 +85,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                     self.working_to_master_map[p] = master_p
                     self.master_to_working_map[master_p] = p
                     master_params.append(master_p)
-            group['params'] = master_params
+            group["params"] = master_params

     def backward(self, loss: Tensor, *args, **kwargs):
         loss = self.mixed_precision.pre_backward(loss)
@@ -101,24 +106,24 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
         if self.mixed_precision is not None:
             div_scale = self.mixed_precision.get_grad_div_scale()

-        if self.max_norm > 0.:
+        if self.max_norm > 0.0:
             # norm is in fact norm*scale
             clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
             if clip > 1:
                 div_scale = clip * div_scale

         for group in self.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 if p.grad is None:
                     continue
-                p.grad.data.mul_(1. / div_scale)
+                p.grad.data.mul_(1.0 / div_scale)

     def _compute_grad_norm(self) -> float:
-        if self.max_norm <= 0.:
-            return 0.
-        grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
+        if self.max_norm <= 0.0:
+            return 0.0
+        grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
         if len(grads) == 0:
-            return 0.
+            return 0.0
         device = grads[0].device
         # TODO(ver217): support tp
         total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
@@ -130,7 +135,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
             return
         # prepare grads
         for group in self.optim.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 working_param = self.master_to_working_map[p]
                 if p is working_param:
                     continue
@@ -142,7 +147,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
         self.optim.step(*args, **kwargs)
         # update working params
         for group in self.optim.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 working_param = self.master_to_working_map[p]
                 if p is working_param:
                     continue
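Taken together, `MixedPrecisionOptimizer` wraps a plain torch optimizer, builds fp32 master copies of the working parameters, and routes `backward`/`step` through the fp16 mixin so overflowed steps are skipped and gradients are unscaled before the update. A rough usage sketch follows; the import path, model, and data are assumptions for illustration, and a CUDA device is required by the underlying `DynamicGradScaler`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed import path; the diff shows the class but not its module location.
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer

# Hypothetical fp16 "working" model.
model = nn.Linear(32, 2).cuda().half()
optim = torch.optim.SGD(model.parameters(), lr=1e-3)

# Wrapping creates fp32 master weights and an fp16 loss-scaling mixin.
optim = MixedPrecisionOptimizer(optim, precision="fp16", initial_scale=2**16)

x = torch.randn(8, 32, device="cuda", dtype=torch.float16)
target = torch.randint(0, 2, (8,), device="cuda")

optim.zero_grad()
loss = F.cross_entropy(model(x), target)
optim.backward(loss)  # scales the loss via the mixin, then calls loss.backward()
optim.step()          # skips on overflow, otherwise unscales grads and syncs working params
```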