diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index d27883a8e..0d84384a7 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -1,12 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import torch from abc import ABC, abstractmethod -from colossalai.logging import get_dist_logger -from torch import Tensor from typing import Dict +import torch +from torch import Tensor + +from colossalai.logging import get_dist_logger + __all__ = ['BaseGradScaler']