ColossalAI/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod
from typing import Dict

import torch
from torch import Tensor

from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger

__all__ = ["BaseGradScaler"]


class BaseGradScaler(ABC):
    """A base class for the gradient scaler.

    Args:
        initial_scale (float): the initial loss scale
        verbose (bool): whether to log messages
    """

    def __init__(self, initial_scale: float, verbose: bool):
        assert initial_scale > 0
        self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float)
        self._verbose = verbose

        if self._verbose:
            self._logger = get_dist_logger()

    @property
    def scale(self) -> Tensor:
        """Returns the loss scale."""
        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        """Returns the reciprocal of the loss scale, computed in double
        precision to reduce rounding error before casting back to float.
        """
        return self._scale.double().reciprocal().float()

    def state_dict(self) -> Dict:
        """Returns the state of the gradient scaler as a dict object."""
        state_dict = dict()
        state_dict["scale"] = self.scale
        return state_dict

    def load_state_dict(self, state_dict: Dict) -> None:
        """Loads the state of the gradient scaler from a dict object.

        Args:
            state_dict (dict): the state of the gradient scaler
        """
        self._scale = state_dict["scale"]

    @abstractmethod
    def update(self, overflow: bool) -> None:
        """Updates the loss scale.

        Args:
            overflow (bool): whether an overflow occurred in the last step
        """

    def log(self, message, *args, **kwargs):
        """Logs a message if verbose logging is enabled.

        Args:
            message (str): the message to log
            *args: positional arguments for :class:`colossalai.logging.DistributedLogger`
            **kwargs: keyword arguments for :class:`colossalai.logging.DistributedLogger`
        """
        if self._verbose:
            self._logger.info(message, *args, **kwargs)
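

# --- Usage sketch (not part of the original file) ---
# A minimal concrete subclass, shown here only to illustrate the intended
# extension point: subclasses implement `update` and may adjust `self._scale`.
# `FixedGradScaler` is a hypothetical name invented for this sketch; it is not
# ColossalAI's actual implementation (the library ships its own constant and
# dynamic scalers in this package).
if __name__ == "__main__":

    class FixedGradScaler(BaseGradScaler):
        def update(self, overflow: bool) -> None:
            # keep the scale fixed; a dynamic scaler would shrink the scale
            # on overflow and grow it after a run of overflow-free steps
            if overflow:
                self.log("overflow detected, keeping scale constant")

    scaler = FixedGradScaler(initial_scale=2**16, verbose=False)

    # typical usage: scale the loss before backward(), then unscale the
    # gradients with inv_scale before the optimizer step
    loss = torch.tensor(1.5, device=scaler.scale.device)
    scaled_loss = loss * scaler.scale
    restored = scaled_loss * scaler.inv_scale  # approximately equals loss

    # the scaler state round-trips through state_dict/load_state_dict
    scaler.load_state_dict(scaler.state_dict())
    scaler.update(overflow=False)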