diff --git a/colossalai/amp/torch_amp/torch_amp.py b/colossalai/amp/torch_amp/torch_amp.py
index 5074e9c81..65718d77c 100644
--- a/colossalai/amp/torch_amp/torch_amp.py
+++ b/colossalai/amp/torch_amp/torch_amp.py
@@ -1,17 +1,17 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-import torch.nn as nn
 import torch.cuda.amp as torch_amp
-
+import torch.nn as nn
 from torch import Tensor
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
-from ._grad_scaler import GradScaler
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils import clip_grad_norm_fp32
+from ._grad_scaler import GradScaler
+
 
 class TorchAMPOptimizer(ColossalaiOptimizer):
     """A wrapper class which integrate Pytorch AMP with an optimizer
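
Read as a whole, the hunk is behavior-preserving: it only regroups the module's imports, placing third-party `torch` imports first (alphabetized), then first-party `colossalai` imports, then the local relative import, which matches the section ordering used by import sorters such as isort. For reference, a sketch of the resulting module header as reconstructed from the hunk above (the hunk is truncated at the docstring's first line, so nothing past that point is shown):

```python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.cuda.amp as torch_amp  # third-party torch imports, alphabetized
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.nn.optimizer import ColossalaiOptimizer  # first-party colossalai imports
from colossalai.utils import clip_grad_norm_fp32
from ._grad_scaler import GradScaler  # local relative import, moved to the end


class TorchAMPOptimizer(ColossalaiOptimizer):
    ...
```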