mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 20:23:26 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -9,7 +9,7 @@ from torch import Tensor
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
__all__ = ['BaseGradScaler']
|
||||
__all__ = ["BaseGradScaler"]
|
||||
|
||||
|
||||
class BaseGradScaler(ABC):
|
||||
@@ -30,24 +30,21 @@ class BaseGradScaler(ABC):
|
||||
|
||||
@property
|
||||
def scale(self) -> Tensor:
|
||||
"""Returns the loss scale.
|
||||
"""
|
||||
"""Returns the loss scale."""
|
||||
|
||||
return self._scale
|
||||
|
||||
@property
|
||||
def inv_scale(self) -> Tensor:
|
||||
"""Returns the inverse of the loss scale.
|
||||
"""
|
||||
"""Returns the inverse of the loss scale."""
|
||||
|
||||
return self._scale.double().reciprocal().float()
|
||||
|
||||
def state_dict(self) -> Dict:
|
||||
"""Returns the states of the gradient scaler as a dict object.
|
||||
"""
|
||||
"""Returns the states of the gradient scaler as a dict object."""
|
||||
|
||||
state_dict = dict()
|
||||
state_dict['scale'] = self.scale
|
||||
state_dict["scale"] = self.scale
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict: Dict) -> None:
|
||||
@@ -57,7 +54,7 @@ class BaseGradScaler(ABC):
|
||||
state_dict (dict): the states of the gradient scaler
|
||||
"""
|
||||
|
||||
self._scale = state_dict['scale']
|
||||
self._scale = state_dict["scale"]
|
||||
|
||||
@abstractmethod
|
||||
def update(self, overflow: bool) -> None:
|
||||
@@ -67,8 +64,6 @@ class BaseGradScaler(ABC):
|
||||
overflow (bool): whether overflow occurs
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def log(self, message, *args, **kwargs):
|
||||
"""Log messages.
|
||||
|
||||
|
Reference in New Issue
Block a user