[NFC] polish colossalai/auto_parallel/offload/amp_optimizer.py code style (#4255)

This commit is contained in:
Yanjia0 2023-07-18 10:54:55 +08:00 committed by binmakeswell
parent 85774f0c1f
commit c614a99d28

View File

@ -1,24 +1,25 @@
from typing import Dict, Tuple
from enum import Enum from enum import Enum
from typing import Dict, Tuple
import torch import torch
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule from .base_offload_module import BaseOffloadModule
from .region_manager import RegionManager
from .region import Region from .region import Region
from .region_manager import RegionManager
class OptimState(Enum): class OptimState(Enum):
SCALED = 0 SCALED = 0
UNSCALED = 1 UNSCALED = 1
class AMPOptimizer(ColossalaiOptimizer):
class AMPOptimizer(ColossalaiOptimizer):
""" """
A wrapper for Optimizer. A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
@ -174,4 +175,4 @@ class AMPOptimizer(ColossalaiOptimizer):
# Leverage state_dict() and load_state_dict() to # Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors # recast preexisting per-param state tensors
self.optim.load_state_dict(self.optim.state_dict()) self.optim.load_state_dict(self.optim.state_dict())