[nvme] CPUAdam and HybridAdam support NVMe offload (#1360)

* impl nvme optimizer

* update cpu adam

* add unit test

* update hybrid adam

* update docstr

* add TODOs

* update CI

* fix CI

* fix CI

* fix CI path

* fix CI path

* fix CI path

* fix install tensornvme

* fix CI

* fix CI path

* fix CI env variables

* test CI

* test CI

* fix CI

* fix nvme optim __del__

* fix adam __del__

* fix nvme optim

* fix CI env variables

* fix nvme optim import

* test CI

* test CI

* fix CI
commit c415240db6 (parent 8463290642)
Author: ver217
Date: 2022-07-26 17:25:24 +08:00
Committed by: GitHub
6 changed files with 264 additions and 8 deletions


@@ -3,10 +3,12 @@ import torch
 from colossalai.registry import OPTIMIZERS
 from colossalai.nn.optimizer import CPU_ADAM_CNT
+from .nvme_optimizer import NVMeOptimizer
+from typing import Optional


 @OPTIMIZERS.register_module
-class CPUAdam(torch.optim.Optimizer):
+class CPUAdam(NVMeOptimizer):
     """Implements Adam algorithm.

     Supports parameter updates on both GPU and CPU, depending on the device of the parameters.
@@ -45,6 +47,9 @@ class CPUAdam(torch.optim.Optimizer):
             True for decoupled weight decay (also known as AdamW) (default: True)
         simd_log (boolean, optional): whether to show if you are using SIMD to
             accelerate. (default: False)
+        nvme_offload_fraction (float, optional): Fraction of params to be offloaded to NVMe. Defaults to 0.0.
+        nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
+            If it's ``None``, a random temporary directory will be used. Defaults to None.

     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
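For orientation, a minimal usage sketch of the two new keyword arguments (the parameter names and defaults come from this diff; the import path, model, and values are illustrative assumptions):

import torch
from colossalai.nn.optimizer import CPUAdam  # import path assumed from this repo's layout

model = torch.nn.Linear(1024, 1024)  # placeholder model

# Offload half of the optimizer states to NVMe; with nvme_offload_dir=None
# a random temporary directory would be used instead (per the docstring above).
optimizer = CPUAdam(model.parameters(),
                    lr=1e-3,
                    nvme_offload_fraction=0.5,
                    nvme_offload_dir='/mnt/nvme/offload')  # hypothetical mount point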
@@ -64,10 +69,12 @@ class CPUAdam(torch.optim.Optimizer):
                  eps=1e-8,
                  weight_decay=0,
                  adamw_mode=True,
-                 simd_log=False):
+                 simd_log=False,
+                 nvme_offload_fraction: float = 0.0,
+                 nvme_offload_dir: Optional[str] = None):
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
-        super(CPUAdam, self).__init__(model_params, default_args)
+        super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
         try:
@@ -78,7 +85,8 @@ class CPUAdam(torch.optim.Optimizer):
         self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)

     def __del__(self):
-        if self.cpu_adam_op:
+        super().__del__()
+        if getattr(self, 'cpu_adam_op', None):
             self.cpu_adam_op.destroy_adam(self.opt_id)

     def torch_adam_update(self,
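The new __del__ first delegates to the base class for NVMe cleanup, then guards the native-op teardown with getattr: __del__ also runs on instances whose __init__ raised before cpu_adam_op was assigned, where a direct attribute access would itself raise. A standalone sketch of that failure mode (Demo is a hypothetical class, not the library's code):

class Demo:
    def __init__(self, fail: bool):
        if fail:
            raise RuntimeError('simulated failure before cpu_adam_op is assigned')
        self.cpu_adam_op = object()  # stands in for the native op handle

    def __del__(self):
        # `if self.cpu_adam_op:` would raise AttributeError on a half-built
        # instance; getattr with a default makes teardown a no-op instead.
        if getattr(self, 'cpu_adam_op', None):
            print('destroying native op')

try:
    Demo(fail=True)  # __del__ still runs on the partially initialized object
except RuntimeError:
    pass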
@@ -121,6 +129,7 @@ class CPUAdam(torch.optim.Optimizer):
             with torch.enable_grad():
                 loss = closure()

+        self._pre_step('exp_avg', 'exp_avg_sq')
         for _, group in enumerate(self.param_groups):
             for _, p in enumerate(group['params']):
@@ -137,6 +146,7 @@ class CPUAdam(torch.optim.Optimizer):
                     state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
                     # gradient variances
                     state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
+                    self._post_state_init(p)

                 state['step'] += 1
                 beta1, beta2 = group['betas']
@@ -145,9 +155,11 @@ class CPUAdam(torch.optim.Optimizer):
                     assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size"
                     assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                     assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg_sq should stay on cpu"
+                    self._pre_update(p, 'exp_avg', 'exp_avg_sq')
                     self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                                  group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                                  state['exp_avg'], state['exp_avg_sq'], -1)
+                    self._post_update(p, 'exp_avg', 'exp_avg_sq')
                 elif target_device.type == 'cuda':
                     assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                     assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"
@@ -161,4 +173,5 @@ class CPUAdam(torch.optim.Optimizer):
                                            bias_correction2, self.adamw_mode)
             else:
                 raise RuntimeError
+        self._post_step()
         return loss
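nvme_optimizer.py is among the six changed files but is not shown in this view. The hook names, their arguments, and their call order are visible in the diff above; everything else in the skeleton below is a speculative sketch of the inferred contract, not the library's actual implementation:

from typing import Optional
import torch


class NVMeOptimizer(torch.optim.Optimizer):
    """Speculative skeleton inferred from the call sites in the CPUAdam diff above."""

    def __init__(self, params, defaults: dict,
                 nvme_offload_fraction: float = 0.0,
                 nvme_offload_dir: Optional[str] = None) -> None:
        super().__init__(params, defaults)
        self.nvme_offload_fraction = nvme_offload_fraction
        self.nvme_offload_dir = nvme_offload_dir  # None -> random temp dir (per the docstring)

    def _pre_step(self, *state_keys: str) -> None:
        # Called once at the top of step(): begin moving the named state
        # tensors (e.g. 'exp_avg', 'exp_avg_sq') back from NVMe.
        ...

    def _post_state_init(self, p: torch.nn.Parameter) -> None:
        # Called right after a parameter's state tensors are first created:
        # decide whether this parameter's states live in RAM or on NVMe.
        ...

    def _pre_update(self, p: torch.nn.Parameter, *state_keys: str) -> None:
        # Ensure the named states for `p` are resident in CPU memory
        # before the fused adam_update kernel touches them.
        ...

    def _post_update(self, p: torch.nn.Parameter, *state_keys: str) -> None:
        # Write the updated states for `p` back out to NVMe.
        ...

    def _post_step(self) -> None:
        # Called once at the end of step(): flush outstanding offload I/O.
        ...

    def __del__(self):
        # Clean up offload files (and the temp dir if one was created).
        ...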