mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[nvme] CPUAdam and HybridAdam support NVMe offload (#1360)
* impl nvme optimizer * update cpu adam * add unit test * update hybrid adam * update docstr * add TODOs * update CI * fix CI * fix CI * fix CI path * fix CI path * fix CI path * fix install tensornvme * fix CI * fix CI path * fix CI env variables * test CI * test CI * fix CI * fix nvme optim __del__ * fix adam __del__ * fix nvme optim * fix CI env variables * fix nvme optim import * test CI * test CI * fix CI
This commit is contained in:
160
colossalai/nn/optimizer/nvme_optimizer.py
Normal file
160
colossalai/nn/optimizer/nvme_optimizer.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import torch
|
||||
import os
|
||||
import tempfile
|
||||
import math
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Optional, List, Dict, Callable
|
||||
|
||||
|
||||
class NVMeOptimizer(torch.optim.Optimizer):
|
||||
"""A base class for offloading optimizer states.
|
||||
|
||||
Args:
|
||||
params: parameters
|
||||
defaults (dict): default dict
|
||||
nvme_offload_fraction (float, optional): Fraction of params to be offloaded to NVMe. Defaults to 0.0.
|
||||
offload_dir (Optional[str], optional): Directory to save NVMe offload files.
|
||||
If it's ``None``, a random temporary directory will be used. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ImportError: Raise if ``tensornvme`` is not installed.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
params,
|
||||
defaults: dict,
|
||||
nvme_offload_fraction: float = 0.0,
|
||||
offload_dir: Optional[str] = None) -> None:
|
||||
assert 0.0 <= nvme_offload_fraction <= 1.0
|
||||
super().__init__(params, defaults)
|
||||
self.nvme_offload_fraction = float(nvme_offload_fraction)
|
||||
if self.nvme_offload_fraction > 0.0:
|
||||
try:
|
||||
from tensornvme import DiskOffloader
|
||||
from tensornvme._C import get_backends
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError('Please install tensornvme to use NVMeOptimizer')
|
||||
self.offload_dir = offload_dir or tempfile.mkdtemp()
|
||||
backend = 'uring' if 'uring' in get_backends() else 'aio'
|
||||
self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend)
|
||||
else:
|
||||
self.offload_dir = None
|
||||
self.offloader = None
|
||||
self.is_on_nvme: Dict[Parameter, bool] = {}
|
||||
self.offloaded_numel: int = 0
|
||||
self.total_numel: int = self._get_numel()
|
||||
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
|
||||
|
||||
self.prefetch_params: List[Parameter] = []
|
||||
self.param_to_prefetch_idx: Dict[Parameter, int] = {}
|
||||
|
||||
def _get_numel(self) -> int:
|
||||
numel = 0
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
numel += p.storage().size()
|
||||
return numel
|
||||
|
||||
def _post_state_init(self, param: Parameter) -> None:
|
||||
numel = param.storage().size()
|
||||
if self.offloader is not None and param.device.type == 'cpu' and numel + self.offloaded_numel <= self.can_offload_numel:
|
||||
self.is_on_nvme[param] = True
|
||||
self.offloaded_numel += numel
|
||||
else:
|
||||
self.is_on_nvme[param] = False
|
||||
|
||||
def _setup_prefetch_params(self) -> List[Parameter]:
|
||||
if self.offloader is None:
|
||||
return
|
||||
assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
if len(self.state[p]) > 0 and self.is_on_nvme[p]:
|
||||
assert p.device.type == 'cpu'
|
||||
self.param_to_prefetch_idx[p] = len(self.prefetch_params)
|
||||
self.prefetch_params.append(p)
|
||||
|
||||
def _pre_step(self, *state_keys: str) -> None:
|
||||
self._setup_prefetch_params()
|
||||
if self.offloader is None or len(self.prefetch_params) == 0:
|
||||
return
|
||||
state = self.state[self.prefetch_params[0]]
|
||||
for key in state_keys:
|
||||
self.offloader.async_read(state[key])
|
||||
|
||||
def _pre_update(self, param: Parameter, *state_keys: str) -> None:
|
||||
if self.offloader is None or param not in self.param_to_prefetch_idx:
|
||||
return
|
||||
self.offloader.sync_read_events()
|
||||
idx = self.param_to_prefetch_idx[param]
|
||||
if idx + 1 < len(self.prefetch_params):
|
||||
state = self.state[self.prefetch_params[idx + 1]]
|
||||
for key in state_keys:
|
||||
self.offloader.async_read(state[key])
|
||||
|
||||
def _post_update(self, param: Parameter, *state_keys: str) -> None:
|
||||
if self.offloader is None:
|
||||
return
|
||||
self.offloader.sync_write_events()
|
||||
if self.is_on_nvme[param]:
|
||||
state = self.state[param]
|
||||
for key in state_keys:
|
||||
self.offloader.async_write(state[key])
|
||||
|
||||
def _post_step(self) -> None:
|
||||
if self.offloader is not None:
|
||||
self.offloader.synchronize()
|
||||
self.prefetch_params.clear()
|
||||
self.param_to_prefetch_idx.clear()
|
||||
|
||||
def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]:
|
||||
"""Performs a single optimization step (parameter update).
|
||||
|
||||
Example:
|
||||
|
||||
>>> self._pre_step('exp_avg', 'exp_avg_sq')
|
||||
>>> for group in self.param_groups:
|
||||
>>> for p in group['params']:
|
||||
>>> if p.grad is None:
|
||||
>>> continue
|
||||
>>> state = self.state[p]
|
||||
>>> if len(state) == 0:
|
||||
>>> state['exp_avg'] = ...
|
||||
>>> state['exp_avg_sq'] = ...
|
||||
>>> self._post_state_init(p)
|
||||
>>> if p.device.type == 'cpu':
|
||||
>>> self._pre_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
>>> adam()
|
||||
>>> self._post_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
>>> else:
|
||||
>>> ...
|
||||
>>> self._post_step()
|
||||
|
||||
Args:
|
||||
closure (Optional[Callable[[], float]], optional): A closure that reevaluates the model and
|
||||
returns the loss. Optional for most optimizers.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
# TODO(ver217): design a new method to save state_dict. When using NVMe offload, this method may lead to OOM.
|
||||
if self.offloader is not None:
|
||||
raise NotImplementedError
|
||||
return super().state_dict()
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
# TODO(ver217): design a new method to load state_dict. When using NVMe offload, whole state_dict may not be able to fit in memory.
|
||||
if self.offloader is not None:
|
||||
raise NotImplementedError
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if getattr(self, 'offloader', None) is not None:
|
||||
del self.offloader
|
||||
if os.path.exists(self.offload_dir):
|
||||
try:
|
||||
os.rmdir(self.offload_dir)
|
||||
except OSError:
|
||||
pass
|
Reference in New Issue
Block a user