diff --git a/colossalai/amp/naive_amp/grad_scaler/__init__.py b/colossalai/amp/naive_amp/grad_scaler/__init__.py
new file mode 100644
index 000000000..dc8499d87
--- /dev/null
+++ b/colossalai/amp/naive_amp/grad_scaler/__init__.py
@@ -0,0 +1,5 @@
+from .base_grad_scaler import BaseGradScaler
+from .constant_grad_scaler import ConstantGradScaler
+from .dynamic_grad_scaler import DynamicGradScaler
+
+__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler']
diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
new file mode 100644
index 000000000..fb279baf6
--- /dev/null
+++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+from abc import ABC, abstractmethod
+from colossalai.logging import get_dist_logger
+from torch import Tensor
+from typing import Dict
+
+__all__ = ['BaseGradScaler']
+
+
+class BaseGradScaler(ABC):
+
+    def __init__(self, initial_scale: int, verbose: bool):
+        assert initial_scale > 0
+        self._scale = torch.cuda.FloatTensor([initial_scale])
+        self._verbose = verbose
+
+        if self._verbose:
+            self._logger = get_dist_logger()
+
+    @property
+    def scale(self) -> Tensor:
+        return self._scale
+
+    @property
+    def inv_scale(self) -> Tensor:
+        return self._scale.double().reciprocal().float()
+
+    def state_dict(self) -> Dict:
+        state_dict = dict()
+        state_dict['scale'] = self.scale
+        return state_dict
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        self._scale = state_dict['scale']
+
+    @abstractmethod
+    def update(self, overflow: bool) -> None:
+        pass
+
+    def log(self, message, *args, **kwargs):
+        if self._verbose:
+            self._logger.info(message, *args, **kwargs)
diff --git a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
new file mode 100644
index 000000000..5f79462b3
--- /dev/null
+++ b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+from .base_grad_scaler import BaseGradScaler
+
+__all__ = ['ConstantGradScaler']
+
+
+class ConstantGradScaler(BaseGradScaler):
+
+    def __init__(self, initial_scale: int, verbose: bool):
+        super().__init__(initial_scale, verbose)
+        self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0])
+
+    def update(self, overflow: bool) -> None:
+        # do nothing to maintain the current scale value
+        pass
diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
new file mode 100644
index 000000000..79fd0f3a3
--- /dev/null
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+from .base_grad_scaler import BaseGradScaler
+
+__all__ = ['DynamicGradScaler']
+
+
+class DynamicGradScaler(BaseGradScaler):
+
+    def __init__(self,
+                 initial_scale: int = 2**16,
+                 growth_factor: int = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: int = 1000,
+                 min_scale: int = None,
+                 max_scale: int = None,
+                 hysteresis: int = 2,
+                 verbose: bool = False):
+        super().__init__(initial_scale, verbose)
+        self._min_scale = min_scale
+        self._max_scale = max_scale
+        self._growth_factor = growth_factor
+        self._backoff_factor = backoff_factor
+        self._growth_interval = growth_interval
+        self._growth_step = 0
+        self._hysteresis = hysteresis
+        self._hysteresis_step = 0
+        self._sanity_checks()
+
+    def _sanity_checks(self) -> None:
+        if self._min_scale:
+            assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
+        if self._max_scale:
+            assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
+        assert self._growth_factor > 1, 'The growth factor cannot be equal to or smaller than 1'
+        assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
+        assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
+
+    def update(self, overflow: bool) -> None:
+        if overflow:
+            self._hysteresis_step += 1
+            self._growth_step = 0
+
+            if self._hysteresis_step >= self._hysteresis:
+                self._backoff_scale()
+                self.log(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}", ranks=[0])
+        else:
+            self._growth_step += 1
+            if self._growth_step == self._growth_interval:
+                self._growth_step = 0
+                self._hysteresis_step = 0
+                self._grow_scale()
+                self.log(
+                    f"No overflow for consecutive {self._growth_interval} steps, "
+                    f"the loss scale is adjusted to {self.scale.item()}",
+                    ranks=[0])
+
+    def _backoff_scale(self) -> None:
+        self._scale = self._scale * self._backoff_factor
+        if self._min_scale:
+            self._scale = torch.clamp(self._scale, min=self._min_scale)
+
+    def _grow_scale(self) -> None:
+        self._scale = self._scale * self._growth_factor
+        if self._max_scale:
+            self._scale = torch.clamp(self._scale, max=self._max_scale)
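
For reference, a minimal, hypothetical usage sketch of the new `DynamicGradScaler` (not part of the diff). It assumes a CUDA device is available, since `BaseGradScaler` stores the scale as a `torch.cuda.FloatTensor`; the constructor values and the overflow flags below are made up purely to show how the scale reacts, whereas in real training the flags would come from an inf/nan check on the gradients after the backward pass.

```python
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

# Small growth_interval/hysteresis values so the behaviour is visible in a few steps.
scaler = DynamicGradScaler(initial_scale=2**16, growth_interval=2, hysteresis=2)

# Hard-coded overflow flags for illustration only.
for step, overflow in enumerate([False, False, True, True, False, False]):
    scaler.update(overflow)
    print(f"step {step}: overflow={overflow}, scale={scaler.scale.item()}")

# With these settings: two consecutive non-overflow steps double the scale
# (growth_factor=2), while the scale is halved (backoff_factor=0.5) only after
# two overflow steps have accumulated (hysteresis=2).
```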