diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/amp/torch_amp/__init__.py index 8943b86d6..893cc890d 100644 --- a/colossalai/amp/torch_amp/__init__.py +++ b/colossalai/amp/torch_amp/__init__.py @@ -1,10 +1,13 @@ -import torch.nn as nn -from torch.optim import Optimizer -from torch.nn.modules.loss import _Loss -from colossalai.context import Config -from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss from typing import Optional +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer + +from colossalai.context import Config + +from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer + def convert_to_torch_amp(model: nn.Module, optimizer: Optimizer,