diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py index 16da81f23..963215476 100644 --- a/colossalai/amp/__init__.py +++ b/colossalai/amp/__init__.py @@ -1,14 +1,16 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from .amp_type import AMP_TYPE -from colossalai.context import Config import torch.nn as nn -from torch.optim import Optimizer from torch.nn.modules.loss import _Loss -from .torch_amp import convert_to_torch_amp +from torch.optim import Optimizer + +from colossalai.context import Config + +from .amp_type import AMP_TYPE from .apex_amp import convert_to_apex_amp from .naive_amp import convert_to_naive_amp +from .torch_amp import convert_to_torch_amp __all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']