diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py
index a79e5006e..19d85b80d 100644
--- a/colossalai/auto_parallel/offload/amp_optimizer.py
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -1,24 +1,25 @@
-from typing import Dict, Tuple
 from enum import Enum
+from typing import Dict, Tuple
+
 import torch
 from torch.optim import Optimizer
 
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.utils import get_current_device
 
 from .base_offload_module import BaseOffloadModule
-from .region_manager import RegionManager
 from .region import Region
+from .region_manager import RegionManager
 
 
 class OptimState(Enum):
     SCALED = 0
     UNSCALED = 1
 
 
-class AMPOptimizer(ColossalaiOptimizer): 
+class AMPOptimizer(ColossalaiOptimizer):
     """
     A wrapper for Optimizer.
     Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
@@ -174,4 +175,4 @@ class AMPOptimizer(ColossalaiOptimizer):
 
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
-        self.optim.load_state_dict(self.optim.state_dict())
\ No newline at end of file
+        self.optim.load_state_dict(self.optim.state_dict())