Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-20 12:43:55 +00:00
[NFC] polish colossalai/auto_parallel/offload/amp_optimizer.py code style (#4255)
parent 85774f0c1f
commit c614a99d28
--- a/colossalai/auto_parallel/offload/amp_optimizer.py
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -1,24 +1,25 @@
-from typing import Dict, Tuple
 from enum import Enum
+from typing import Dict, Tuple
 
 import torch
 from torch.optim import Optimizer
 
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.utils import get_current_device
 
 from .base_offload_module import BaseOffloadModule
-from .region_manager import RegionManager
 from .region import Region
+from .region_manager import RegionManager
 
 
 class OptimState(Enum):
     SCALED = 0
     UNSCALED = 1
 
+
 class AMPOptimizer(ColossalaiOptimizer):
     """
     A wrapper for Optimizer.
     Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
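For context (not part of the commit itself): the reordering above matches isort's default grouping — standard-library imports first, then third-party (torch), then first-party (colossalai), then local relative imports, with each group alphabetized. A minimal sketch of how such a polish can be verified, assuming isort is installed; the check below is illustrative and not taken from this repo's tooling:

    # Hypothetical check: confirm a file's imports already follow the
    # ordering this commit applies. isort.check_file() returns True when
    # the import order complies; isort.file() would rewrite it in place.
    import isort

    ok = isort.check_file("colossalai/auto_parallel/offload/amp_optimizer.py")
    print("imports sorted:", ok)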
@@ -174,4 +175,4 @@ class AMPOptimizer(ColossalaiOptimizer):
 
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
-        self.optim.load_state_dict(self.optim.state_dict())
\ No newline at end of file
+        self.optim.load_state_dict(self.optim.state_dict())
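The line touched by the second hunk relies on documented PyTorch behavior: Optimizer.load_state_dict() casts each floating-point per-parameter state tensor to the dtype and device of the parameter it belongs to, so round-tripping the optimizer's own state_dict() "recasts preexisting per-param state tensors" (as the diff's comment says) after the wrapped parameters change precision. A minimal standalone sketch of that effect, not taken from the commit:

    import torch

    # One fp32 parameter with Adam state already materialized.
    param = torch.nn.Parameter(torch.randn(4))
    optim = torch.optim.Adam([param], lr=1e-3)
    param.grad = torch.zeros_like(param)
    optim.step()  # creates fp32 exp_avg / exp_avg_sq for `param`

    # Suppose the surrounding module later casts its parameters to fp16.
    param.data = param.data.half()

    # Round-tripping the state dict recasts the existing per-param state
    # tensors to match the parameter's new dtype and device.
    optim.load_state_dict(optim.state_dict())
    assert optim.state[param]["exp_avg"].dtype == torch.float16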