From c614a99d286087b5768cc6422b7317dcff02db3e Mon Sep 17 00:00:00 2001
From: Yanjia0 <42895286+Yanjia0@users.noreply.github.com>
Date: Tue, 18 Jul 2023 10:54:55 +0800
Subject: [PATCH] [NFC] polish colossalai/auto_parallel/offload/amp_optimizer.py
 code style (#4255)

---
 colossalai/auto_parallel/offload/amp_optimizer.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py
index a79e5006e..19d85b80d 100644
--- a/colossalai/auto_parallel/offload/amp_optimizer.py
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -1,24 +1,25 @@
-from typing import Dict, Tuple
 from enum import Enum
+from typing import Dict, Tuple
+
 import torch
 from torch.optim import Optimizer
 
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.utils import get_current_device
 
 from .base_offload_module import BaseOffloadModule
-from .region_manager import RegionManager
 from .region import Region
+from .region_manager import RegionManager
 
 
 class OptimState(Enum):
     SCALED = 0
     UNSCALED = 1
 
 
-class AMPOptimizer(ColossalaiOptimizer):
+class AMPOptimizer(ColossalaiOptimizer):
     """
     A wrapper for Optimizer.
     Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
@@ -174,4 +175,4 @@ class AMPOptimizer(ColossalaiOptimizer):
 
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
-        self.optim.load_state_dict(self.optim.state_dict())
\ No newline at end of file
+        self.optim.load_state_dict(self.optim.state_dict())