mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -3,7 +3,7 @@

## Introduction

Welcome to the large-scale deep learning optimization techniques of [Colossal-AI](https://github.com/hpcaitech/ColossalAI),
which has been accepted as official tutorials by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
@@ -6,4 +6,4 @@ from .hybrid_adam import HybridAdam
from .lamb import Lamb
from .lars import Lars

__all__ = ['FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam']
__all__ = ["FusedLAMB", "FusedAdam", "FusedSGD", "Lamb", "Lars", "CPUAdam", "HybridAdam"]
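Note: this hunk only changes quote style; the exported optimizer classes are unchanged. Below is a minimal usage sketch of one of them. The absolute package path `colossalai.nn.optimizer` is an assumption (only relative imports appear in this diff); Lamb is chosen because it is pure PyTorch and needs no fused kernels.

```python
# Minimal usage sketch; import path is assumed, not shown in this diff.
import torch
from colossalai.nn.optimizer import Lamb  # assumed import path

model = torch.nn.Linear(128, 64)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```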
@@ -61,36 +61,39 @@ class CPUAdam(NVMeOptimizer):
# Param weight, grad, momentum and variance
num_fp32_shards_per_param = 4

def __init__(self,
model_params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
adamw_mode=True,
nvme_offload_fraction: float = 0.0,
nvme_offload_dir: Optional[str] = None):

def __init__(
self,
model_params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
adamw_mode=True,
nvme_offload_fraction: float = 0.0,
nvme_offload_dir: Optional[str] = None,
):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
cpu_adam = CPUAdamBuilder().load()
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

def torch_adam_update(self,
data,
grad,
exp_avg,
exp_avg_sq,
lr,
beta1,
beta2,
eps,
weight_decay,
bias_correction1,
bias_correction2,
use_adamw=False):
def torch_adam_update(
self,
data,
grad,
exp_avg,
exp_avg_sq,
lr,
beta1,
beta2,
eps,
weight_decay,
bias_correction1,
bias_correction2,
use_adamw=False,
):
grad = grad.to(data.dtype)

if weight_decay != 0:
@@ -117,10 +120,9 @@ class CPUAdam(NVMeOptimizer):
with torch.enable_grad():
loss = closure()

self._pre_step('exp_avg', 'exp_avg_sq')
self._pre_step("exp_avg", "exp_avg_sq")
for _, group in enumerate(self.param_groups):
for _, p in enumerate(group['params']):

for _, p in enumerate(group["params"]):
if p.grad is None:
continue
@@ -128,48 +130,81 @@ class CPUAdam(NVMeOptimizer):

target_device = p.device
if len(state) == 0:
state['step'] = 0
state["step"] = 0

# FIXME(ver217): CPU adam kernel only supports fp32 states now
assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
# gradient momentums
state['exp_avg'] = torch.zeros_like(p, device=target_device)
state["exp_avg"] = torch.zeros_like(p, device=target_device)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
state["exp_avg_sq"] = torch.zeros_like(p, device=target_device)
self._post_state_init(p)

state['step'] += 1
beta1, beta2 = group['betas']
state["step"] += 1
beta1, beta2 = group["betas"]

if target_device.type == 'cpu':
if target_device.type == "cpu":
assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size"
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
bias_correction2, self.adamw_mode)
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
self.torch_adam_update(
p.data,
p.grad.data,
state["exp_avg"],
state["exp_avg_sq"],
group["lr"],
beta1,
beta2,
group["eps"],
group["weight_decay"],
bias_correction1,
bias_correction2,
self.adamw_mode,
)
else:
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda':
self.cpu_adam_op.step(
state["step"],
group["lr"],
beta1,
beta2,
group["eps"],
group["weight_decay"],
group["bias_correction"],
p.data,
p.grad.data,
state["exp_avg"],
state["exp_avg_sq"],
div_scale,
)
self._post_update(p, "exp_avg", "exp_avg_sq")
elif target_device.type == "cuda":
assert div_scale == -1, "div_scale should remain default"
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda"
assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda"

bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]

# adam on cuda
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
bias_correction2, self.adamw_mode)
self.torch_adam_update(
p.data,
p.grad.data,
state["exp_avg"],
state["exp_avg_sq"],
group["lr"],
beta1,
beta2,
group["eps"],
group["weight_decay"],
bias_correction1,
bias_correction2,
self.adamw_mode,
)
else:
raise RuntimeError
self._post_step()
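Note: the reformatted call sites above pass precomputed bias-correction terms into `torch_adam_update`. The method body itself is not part of this diff, so the following is only a generic sketch of a dense Adam/AdamW update with the same argument order, not the repository's exact code.

```python
# Generic sketch of a dense Adam/AdamW update matching the torch_adam_update
# signature shown above. The actual method body is not shown in this diff.
import torch

def adam_update_sketch(data, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps,
                       weight_decay, bias_correction1, bias_correction2, use_adamw=False):
    grad = grad.to(data.dtype)
    if weight_decay != 0:
        if use_adamw:
            data.mul_(1 - lr * weight_decay)            # decoupled decay (AdamW)
        else:
            grad = grad.add(data, alpha=weight_decay)   # classic L2 penalty
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)               # first moment
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment
    denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
    data.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
```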
@@ -1,11 +1,11 @@
# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py
'''
"""
Copyright 2020 The Microsoft DeepSpeed Team

Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
'''
"""
import torch

from colossalai.utils import multi_tensor_applier
@@ -51,37 +51,39 @@ class FusedAdam(torch.optim.Optimizer):
https://openreview.net/forum?id=ryQu7f-RZ
"""

def __init__(self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
adamw_mode=True,
weight_decay=0.,
amsgrad=False,
set_grad_none=True):

def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
adamw_mode=True,
weight_decay=0.0,
amsgrad=False,
set_grad_none=True,
):
if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
super(FusedAdam, self).__init__(params, defaults)
self.adamw_mode = 1 if adamw_mode else 0
self.set_grad_none = set_grad_none
if multi_tensor_applier.available:
from colossalai.kernel.op_builder import FusedOptimBuilder

fused_optim = FusedOptimBuilder().load()

# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_adam = fused_optim.multi_tensor_adam
else:
raise RuntimeError('FusedAdam requires cuda extensions')
raise RuntimeError("FusedAdam requires cuda extensions")

def zero_grad(self, set_to_none=False):
if set_to_none:
for group in self.param_groups:
for p in group['params']:
for p in group["params"]:
p.grad = None
else:
super(FusedAdam, self).zero_grad()
@@ -97,51 +99,63 @@ class FusedAdam(torch.optim.Optimizer):
"""
if any(p is not None for p in [grads, output_params, scale, grad_norms]):
raise RuntimeError(
'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.'
"FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments."
)
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
bias_correction = 1 if group["bias_correction"] else 0
beta1, beta2 = group["betas"]

# assume same step across group now to simplify things
# per parameter step can be easily support by making it tensor, or pass list into kernel
if 'step' in group:
group['step'] += 1
if "step" in group:
group["step"] += 1
else:
group['step'] = 1
group["step"] = 1

# create lists for multi-tensor apply
g_l, p_l, m_l, v_l = [], [], [], []

for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
if p.grad.data.is_sparse:
raise RuntimeError(
'FusedAdam does not support sparse gradients, please consider SparseAdam instead')
"FusedAdam does not support sparse gradients, please consider SparseAdam instead"
)

state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)
state["exp_avg"] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)

if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]:
raise RuntimeError('FusedAdam only support fp16, fp32 and bf16.')
raise RuntimeError("FusedAdam only support fp16, fp32 and bf16.")

g_l.append(p.grad.data)
p_l.append(p.data)
m_l.append(state['exp_avg'])
v_l.append(state['exp_avg_sq'])
m_l.append(state["exp_avg"])
v_l.append(state["exp_avg_sq"])

multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction,
group['weight_decay'], div_scale)
multi_tensor_applier(
self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_l, p_l, m_l, v_l],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adamw_mode,
bias_correction,
group["weight_decay"],
div_scale,
)

return loss
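Note: the error message in this hunk states that FusedAdam should be constructed exactly like `torch.optim.Adam` and that `step()` takes no extra arguments. A minimal sketch of that usage follows; it assumes a CUDA device, that the fused optimizer extension can be built, and the `colossalai.nn.optimizer` import path (not shown in this diff).

```python
# Minimal sketch following the error message above: construct FusedAdam like
# torch.optim.Adam and call step() with no arguments. Assumes CUDA and that the
# fused extension loads; import path is assumed.
import torch
from colossalai.nn.optimizer import FusedAdam  # assumed import path

model = torch.nn.Linear(128, 64).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0)

loss = model(torch.randn(4, 128, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```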
@@ -49,41 +49,46 @@ class FusedLAMB(torch.optim.Optimizer):
https://openreview.net/forum?id=ryQu7f-RZ
"""

def __init__(self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=0.01,
amsgrad=False,
adam_w_mode=True,
grad_averaging=True,
set_grad_none=True,
max_grad_norm=1.0,
use_nvlamb=False):
def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=0.01,
amsgrad=False,
adam_w_mode=True,
grad_averaging=True,
set_grad_none=True,
max_grad_norm=1.0,
use_nvlamb=False,
):
if amsgrad:
raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
defaults = dict(lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm)
raise RuntimeError("FusedLAMB does not support the AMSGrad variant.")
defaults = dict(
lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm,
)
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
from colossalai.kernel.op_builder import FusedOptimBuilder

fused_optim = FusedOptimBuilder().load()

self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0],
dtype=torch.int,
device=self.param_groups[0]["params"][0].device)
self._dummy_overflow_buf = torch.tensor(
[0], dtype=torch.int, device=self.param_groups[0]["params"][0].device
)
self.multi_tensor_lamb = fused_optim.multi_tensor_lamb
else:
raise RuntimeError('FusedLAMB requires cuda extensions')
raise RuntimeError("FusedLAMB requires cuda extensions")

self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
@@ -92,7 +97,7 @@ class FusedLAMB(torch.optim.Optimizer):
def zero_grad(self):
if self.set_grad_none:
for group in self.param_groups:
for p in group['params']:
for p in group["params"]:
p.grad = None
else:
super(FusedLAMB, self).zero_grad()
@@ -111,7 +116,7 @@ class FusedLAMB(torch.optim.Optimizer):
# create separate grad lists for fp32 and fp16 params
g_all_32, g_all_16 = [], []
for group in self.param_groups:
for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
if p.dtype == torch.float32:
@@ -119,7 +124,7 @@ class FusedLAMB(torch.optim.Optimizer):
elif p.dtype == torch.float16:
g_all_16.append(p.grad.data)
else:
raise RuntimeError('FusedLAMB only support fp16 and fp32.')
raise RuntimeError("FusedLAMB only support fp16 and fp32.")

device = self.param_groups[0]["params"][0].device
g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
@@ -130,63 +135,91 @@ class FusedLAMB(torch.optim.Optimizer):
g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False)[0]

# blend two grad norms to get global grad norm
global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf,
[[g_norm_32, g_norm_16]], False)[0]
max_grad_norm = self.defaults['max_grad_norm']
global_grad_norm = multi_tensor_applier(
self.multi_tensor_l2norm, self._dummy_overflow_buf, [[g_norm_32, g_norm_16]], False
)[0]
max_grad_norm = self.defaults["max_grad_norm"]

for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
bias_correction = 1 if group["bias_correction"] else 0
beta1, beta2 = group["betas"]
grad_averaging = 1 if group["grad_averaging"] else 0

# assume same step across group now to simplify things
# per parameter step can be easily support by making it tensor, or pass list into kernel
if 'step' in group:
group['step'] += 1
if "step" in group:
group["step"] += 1
else:
group['step'] = 1
group["step"] = 1

# create lists for multi-tensor apply
g_16, p_16, m_16, v_16 = [], [], [], []
g_32, p_32, m_32, v_32 = [], [], [], []

for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
if p.grad.data.is_sparse:
raise RuntimeError(
'FusedLAMB does not support sparse gradients, please consider SparseAdam instead')
"FusedLAMB does not support sparse gradients, please consider SparseAdam instead"
)

state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)
state["exp_avg"] = torch.zeros_like(p)
# Exponential moving average of gradient values
state['exp_avg_sq'] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)

if p.dtype == torch.float16:
g_16.append(p.grad.data)
p_16.append(p.data)
m_16.append(state['exp_avg'])
v_16.append(state['exp_avg_sq'])
m_16.append(state["exp_avg"])
v_16.append(state["exp_avg_sq"])
elif p.dtype == torch.float32:
g_32.append(p.grad.data)
p_32.append(p.data)
m_32.append(state['exp_avg'])
v_32.append(state['exp_avg_sq'])
m_32.append(state["exp_avg"])
v_32.append(state["exp_avg_sq"])
else:
raise RuntimeError('FusedLAMB only support fp16 and fp32.')
raise RuntimeError("FusedLAMB only support fp16 and fp32.")

if (len(g_16) > 0):
multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction,
group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm,
max_grad_norm, self.use_nvlamb)
if (len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction,
group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm,
max_grad_norm, self.use_nvlamb)
if len(g_16) > 0:
multi_tensor_applier(
self.multi_tensor_lamb,
self._dummy_overflow_buf,
[g_16, p_16, m_16, v_16],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
bias_correction,
group["weight_decay"],
grad_averaging,
self.adam_w_mode,
global_grad_norm,
max_grad_norm,
self.use_nvlamb,
)
if len(g_32) > 0:
multi_tensor_applier(
self.multi_tensor_lamb,
self._dummy_overflow_buf,
[g_32, p_32, m_32, v_32],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
bias_correction,
group["weight_decay"],
grad_averaging,
self.adam_w_mode,
global_grad_norm,
max_grad_norm,
self.use_nvlamb,
)

return loss
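Note: the "blend two grad norms" step above combines the fp32 and fp16 partial L2 norms into one global gradient norm via the fused l2norm kernel. Numerically that is just the L2 norm of the two partial norms, as in this plain-PyTorch sketch (not the fused kernel itself):

```python
# Plain-PyTorch equivalent of the "blend two grad norms" step above.
import torch

g_norm_32 = torch.tensor([3.0])  # ||fp32 grads||_2
g_norm_16 = torch.tensor([4.0])  # ||fp16 grads||_2
global_grad_norm = torch.linalg.vector_norm(torch.cat([g_norm_32, g_norm_16]))
print(global_grad_norm)  # tensor(5.)
```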
@@ -54,14 +54,9 @@ class FusedSGD(Optimizer):
The Nesterov version is analogously modified.
"""

def __init__(self,
params,
lr=required,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
wd_after_momentum=False):
def __init__(
self, params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False, wd_after_momentum=False
):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
@@ -78,20 +73,21 @@ class FusedSGD(Optimizer):

if multi_tensor_applier.available:
from colossalai.kernel.op_builder import FusedOptimBuilder

fused_optim = FusedOptimBuilder().load()

# Skip buffer
self._dummy_overflow_buf = torch.tensor([0],
dtype=torch.int,
device=self.param_groups[0]["params"][0].device)
self._dummy_overflow_buf = torch.tensor(
[0], dtype=torch.int, device=self.param_groups[0]["params"][0].device
)
self.multi_tensor_sgd = fused_optim.multi_tensor_sgd
else:
raise RuntimeError('FusedSGD requires cuda extensions')
raise RuntimeError("FusedSGD requires cuda extensions")

def __setstate__(self, state):
super(FusedSGD, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
group.setdefault("nesterov", False)

def get_momentums(self, params):
momentums = []
@@ -101,13 +97,13 @@ class FusedSGD(Optimizer):
# torch.optim.SGD initializes momentum in the main loop, we have
# to do it here, and track whether or not we've done so, so that
# momentum application can be skipped in the main kernel.
if 'momentum_buffer' not in param_state:
if "momentum_buffer" not in param_state:
first_run = True
buf = param_state['momentum_buffer'] = torch.zeros_like(p)
buf = param_state["momentum_buffer"] = torch.zeros_like(p)
momentums.append(buf)
else:
first_run = False
momentums.append(param_state['momentum_buffer'])
momentums.append(param_state["momentum_buffer"])
return momentums, first_run

def step(self, closure=None):
@@ -122,10 +118,10 @@ class FusedSGD(Optimizer):
loss = closure()

for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
weight_decay = group["weight_decay"]
momentum = group["momentum"]
dampening = group["dampening"]
nesterov = group["nesterov"]

# For each group, there are 3 possible combinations we need to consider:
# grad_type, param_to_update_type, momentum_type
@@ -133,15 +129,26 @@ class FusedSGD(Optimizer):
# 2. fp32, fp32, fp32
# 3. fp16, fp32, fp32
g_l, p_l = [], []
for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
if p.grad.data.is_sparse:
raise RuntimeError('FusedSGD does not support sparse gradients')
raise RuntimeError("FusedSGD does not support sparse gradients")
g_l.append(p.grad)
p_l.append(p)
m_l, first_run = self.get_momentums(p_l)
multi_tensor_applier(self.multi_tensor_sgd, self._dummy_overflow_buf, [g_l, p_l, m_l], weight_decay,
momentum, dampening, group['lr'], nesterov, first_run, self.wd_after_momentum, 1.0)
multi_tensor_applier(
self.multi_tensor_sgd,
self._dummy_overflow_buf,
[g_l, p_l, m_l],
weight_decay,
momentum,
dampening,
group["lr"],
nesterov,
first_run,
self.wd_after_momentum,
1.0,
)

return loss
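Note: for reference, the per-parameter update that the fused multi-tensor SGD call batches up corresponds roughly to the standard torch.optim.SGD momentum update sketched below. This is a reading of the call above, not the kernel source; in particular, how `wd_after_momentum` reorders weight decay relative to the momentum buffer update is not shown in this diff.

```python
# Rough per-parameter equivalent of SGD with momentum (torch.optim.SGD-style
# semantics). The wd_after_momentum ordering of the fused kernel is not shown.
import torch

def sgd_momentum_sketch(p, grad, buf, lr, momentum, dampening, weight_decay, nesterov):
    if weight_decay != 0:
        grad = grad.add(p, alpha=weight_decay)
    buf.mul_(momentum).add_(grad, alpha=1 - dampening)        # momentum buffer
    d_p = grad.add(buf, alpha=momentum) if nesterov else buf  # effective step
    p.add_(d_p, alpha=-lr)
```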
@@ -1,7 +1,6 @@
from typing import Any, Optional

import torch
from torch.optim import Adam

from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.utils import multi_tensor_applier
@@ -61,20 +60,30 @@ class HybridAdam(CPUAdam):
# Param weight, grad, momentum and variance
num_fp32_shards_per_param = 4

def __init__(self,
model_params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
adamw_mode=True,
nvme_offload_fraction: float = 0.0,
nvme_offload_dir: Optional[str] = None,
**defaults: Any):

super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction,
nvme_offload_dir)
def __init__(
self,
model_params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
adamw_mode=True,
nvme_offload_fraction: float = 0.0,
nvme_offload_dir: Optional[str] = None,
**defaults: Any,
):
super().__init__(
model_params,
lr,
bias_correction,
betas,
eps,
weight_decay,
adamw_mode,
nvme_offload_fraction,
nvme_offload_dir,
)
fused_optim = FusedOptimBuilder().load()
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
@@ -86,12 +95,11 @@ class HybridAdam(CPUAdam):
with torch.enable_grad():
loss = closure()

self._pre_step('exp_avg', 'exp_avg_sq')
self._pre_step("exp_avg", "exp_avg_sq")
for _, group in enumerate(self.param_groups):
g_l, p_l, m_l, v_l = [], [], [], []
group_step = 0
for _, p in enumerate(group['params']):

for _, p in enumerate(group["params"]):
if p.grad is None:
continue
@@ -99,54 +107,87 @@ class HybridAdam(CPUAdam):

target_device = p.device
if len(state) == 0:
state['step'] = 0
state["step"] = 0

# FIXME(ver217): CPU adam kernel only supports fp32 states now
assert p.dtype is torch.float, "HybridAdam only support fp32 parameters"
# gradient momentums
state['exp_avg'] = torch.zeros_like(p, device=target_device)
state["exp_avg"] = torch.zeros_like(p, device=target_device)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
state["exp_avg_sq"] = torch.zeros_like(p, device=target_device)
self._post_state_init(p)

state['step'] += 1
group_step = state['step']
beta1, beta2 = group['betas']
state["step"] += 1
group_step = state["step"]
beta1, beta2 = group["betas"]

if target_device.type == 'cpu':
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
if target_device.type == "cpu":
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
bias_correction2, self.adamw_mode)
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
self.torch_adam_update(
p.data,
p.grad.data,
state["exp_avg"],
state["exp_avg_sq"],
group["lr"],
beta1,
beta2,
group["eps"],
group["weight_decay"],
bias_correction1,
bias_correction2,
self.adamw_mode,
)
else:
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(
state["step"],
group["lr"],
beta1,
beta2,
group["eps"],
group["weight_decay"],
group["bias_correction"],
p.data,
p.grad.data,
state["exp_avg"],
state["exp_avg_sq"],
div_scale,
)
self._post_update(p, "exp_avg", "exp_avg_sq")

elif target_device.type == 'cuda':
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
elif target_device.type == "cuda":
assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda"
assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda"

# record the state by group and update at once
g_l.append(p.grad.data)
p_l.append(p.data)
m_l.append(state['exp_avg'])
v_l.append(state['exp_avg_sq'])
m_l.append(state["exp_avg"])
v_l.append(state["exp_avg_sq"])

else:
raise RuntimeError
if len(g_l) > 0:
adamw_mode = 1 if self.adamw_mode else 0
bias_correction = 1 if group['bias_correction'] else 0
multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode,
bias_correction, group['weight_decay'], div_scale)
bias_correction = 1 if group["bias_correction"] else 0
multi_tensor_applier(
self.gpu_adam_op,
self._dummy_overflow_buf,
[g_l, p_l, m_l, v_l],
group["lr"],
group["betas"][0],
group["betas"][1],
group["eps"],
group_step,
adamw_mode,
bias_correction,
group["weight_decay"],
div_scale,
)
self._post_step()
return loss
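Note: as the hunk above shows, HybridAdam routes each parameter by device: CPU-resident parameters go through the CPU Adam kernel (or the torch fallback for bf16 gradients), while CUDA-resident parameters are collected and updated with one fused multi-tensor call per parameter group. A minimal usage sketch, assuming the `colossalai.nn.optimizer` import path and that the CPU and fused CUDA extensions are available:

```python
# Minimal HybridAdam usage sketch; import path is assumed and the compiled
# CPU/CUDA optimizer extensions must be available.
import torch
from colossalai.nn.optimizer import HybridAdam  # assumed import path

model = torch.nn.Linear(128, 64).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3, adamw_mode=True,
                       nvme_offload_fraction=0.0)

loss = model(torch.randn(4, 128, device="cuda")).sum()
loss.backward()
optimizer.step()
```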
@@ -51,27 +51,27 @@ class Lamb(Optimizer):
loss = closure()

for group in self.param_groups:
for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instead.")

state = self.state[p]

# State initialization
if len(state) == 0:
state['step'] = 0
state["step"] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)
state["exp_avg"] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]

state['step'] += 1
state["step"] += 1

# Decay the first and second moment running average coefficient
# m_t
@@ -84,22 +84,22 @@ class Lamb(Optimizer):
# bias_correction2 = 1 - beta2 ** state['step']
# Apply bias to lr to avoid broadcast.
# * math.sqrt(bias_correction2) / bias_correction1
step_size = group['lr']
step_size = group["lr"]

weight_norm = p.data.pow(2).sum().sqrt()

adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
if group['weight_decay'] != 0:
adam_step.add_(p.data, alpha=group['weight_decay'])
adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"])
if group["weight_decay"] != 0:
adam_step.add_(p.data, alpha=group["weight_decay"])

adam_norm = adam_step.pow(2).sum().sqrt()
if weight_norm == 0 or adam_norm == 0:
trust_ratio = 1
else:
trust_ratio = weight_norm / adam_norm
state['weight_norm'] = weight_norm
state['adam_norm'] = adam_norm
state['trust_ratio'] = trust_ratio
state["weight_norm"] = weight_norm
state["adam_norm"] = adam_norm
state["trust_ratio"] = trust_ratio
if self.adam:
trust_ratio = 1
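Note: the trust-ratio logic above is the core of LAMB: the Adam-style step is rescaled layer-wise by ||w|| / ||adam_step||, with a fallback of 1 when either norm is zero. A compact numeric restatement (the final line applying the step is not part of this hunk and is shown here only for illustration):

```python
# Compact restatement of the LAMB trust-ratio scaling shown above.
import torch

def lamb_trust_ratio(param, adam_step):
    weight_norm = param.pow(2).sum().sqrt()
    adam_norm = adam_step.pow(2).sum().sqrt()
    if weight_norm == 0 or adam_norm == 0:
        return 1.0
    return (weight_norm / adam_norm).item()

p = torch.randn(64, 64)
step = torch.randn(64, 64) * 1e-2
# illustrative application: p -= lr * trust_ratio * adam_step
p.add_(step, alpha=-1e-3 * lamb_trust_ratio(p, step))
```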
@@ -19,13 +19,9 @@ class Lars(Optimizer):
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
"""

def __init__(self,
params: Iterable[torch.nn.Parameter],
lr=1e-3,
momentum=0,
eeta=1e-3,
weight_decay=0,
epsilon=0.0) -> None:
def __init__(
self, params: Iterable[torch.nn.Parameter], lr=1e-3, momentum=0, eeta=1e-3, weight_decay=0, epsilon=0.0
) -> None:
if not isinstance(lr, float) or lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
@@ -54,14 +50,14 @@ class Lars(Optimizer):
loss = closure()

for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
eeta = group['eeta']
lr = group['lr']
lars = group['lars']
eps = group['epsilon']
weight_decay = group["weight_decay"]
momentum = group["momentum"]
eeta = group["eeta"]
lr = group["lr"]
lars = group["lars"]
eps = group["epsilon"]

for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
decayed_grad = p.grad
@@ -69,9 +65,11 @@ class Lars(Optimizer):
if lars:
w_norm = torch.norm(p)
g_norm = torch.norm(p.grad)
trust_ratio = torch.where(w_norm > 0 and g_norm > 0,
eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
torch.ones_like(w_norm))
trust_ratio = torch.where(
w_norm > 0 and g_norm > 0,
eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
torch.ones_like(w_norm),
)
trust_ratio.clamp_(0.0, 50)
scaled_lr *= trust_ratio.item()
if weight_decay != 0:
@@ -80,10 +78,10 @@ class Lars(Optimizer):

if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.clone(decayed_grad).detach()
if "momentum_buffer" not in param_state:
buf = param_state["momentum_buffer"] = torch.clone(decayed_grad).detach()
else:
buf = param_state['momentum_buffer']
buf = param_state["momentum_buffer"]
buf.mul_(momentum).add_(decayed_grad)
decayed_grad = buf
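Note: the LARS trust ratio reformatted two hunks above scales the learning rate by eeta * ||w|| / (||g|| + weight_decay * ||w|| + eps) when both norms are positive, and the ratio is then clamped to [0, 50]. A small standalone sketch of that computation:

```python
# Restatement of the LARS trust-ratio computation shown above.
import torch

def lars_trust_ratio(param, grad, eeta=1e-3, weight_decay=0.0, eps=0.0):
    w_norm = torch.norm(param)
    g_norm = torch.norm(grad)
    if w_norm > 0 and g_norm > 0:
        ratio = eeta * w_norm / (g_norm + weight_decay * w_norm + eps)
    else:
        ratio = torch.ones_like(w_norm)
    return ratio.clamp(0.0, 50).item()

p, g = torch.randn(256), torch.randn(256)
scaled_lr = 0.1 * lars_trust_ratio(p, g, eeta=1e-3, weight_decay=1e-4)
```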
@@ -19,13 +19,11 @@ class NVMeOptimizer(torch.optim.Optimizer):

Raises:
ImportError: Raise if ``tensornvme`` is not installed.
"""
"""

def __init__(self,
params,
defaults: dict,
nvme_offload_fraction: float = 0.0,
offload_dir: Optional[str] = None) -> None:
def __init__(
self, params, defaults: dict, nvme_offload_fraction: float = 0.0, offload_dir: Optional[str] = None
) -> None:
assert 0.0 <= nvme_offload_fraction <= 1.0
super().__init__(params, defaults)
self.nvme_offload_fraction = float(nvme_offload_fraction)
@@ -34,9 +32,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
from tensornvme import DiskOffloader
from tensornvme._C import get_backends
except ModuleNotFoundError:
raise ModuleNotFoundError('Please install tensornvme to use NVMeOptimizer')
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
self.offload_dir = offload_dir or tempfile.mkdtemp()
backend = 'uring' if 'uring' in get_backends() else 'aio'
backend = "uring" if "uring" in get_backends() else "aio"
self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend)
else:
self.offload_dir = None
@@ -53,13 +51,17 @@ class NVMeOptimizer(torch.optim.Optimizer):
def _get_numel(self) -> int:
numel = 0
for group in self.param_groups:
for p in group['params']:
for p in group["params"]:
numel += p.storage().size()
return numel

def _post_state_init(self, param: Parameter) -> None:
numel = param.storage().size()
if self.offloader is not None and param.device.type == 'cpu' and numel + self.offloaded_numel <= self.can_offload_numel:
if (
self.offloader is not None
and param.device.type == "cpu"
and numel + self.offloaded_numel <= self.can_offload_numel
):
self.is_on_nvme[param] = True
self.offloaded_numel += numel
else:
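Note: `_post_state_init` above offloads a parameter's optimizer states to NVMe only when the offloader exists, the parameter lives on CPU, and the running offloaded element count stays within `can_offload_numel`. How `can_offload_numel` is derived is not shown in this diff; a plausible reading, given the `nvme_offload_fraction` argument and `_get_numel`, is a fraction of the total parameter numel. A small sketch of the budget check:

```python
# Sketch of the offload budget check performed in _post_state_init above.
# The derivation of can_offload_numel is an assumption (not in this diff).
def should_offload(offloader, param_device_type, param_numel,
                   offloaded_numel, can_offload_numel):
    return (
        offloader is not None
        and param_device_type == "cpu"
        and param_numel + offloaded_numel <= can_offload_numel
    )

# e.g. with a 0.5 offload fraction over 1_000_000 total elements:
assert should_offload(object(), "cpu", 200_000, 250_000, 500_000)
assert not should_offload(object(), "cpu", 300_000, 250_000, 500_000)
```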
@@ -70,11 +72,11 @@ class NVMeOptimizer(torch.optim.Optimizer):
return
assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0
for group in self.param_groups:
for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
if len(self.state[p]) > 0 and self.is_on_nvme[p]:
assert p.device.type == 'cpu'
assert p.device.type == "cpu"
self.param_to_prefetch_idx[p] = len(self.prefetch_params)
self.prefetch_params.append(p)
@@ -156,7 +158,7 @@ class NVMeOptimizer(torch.optim.Optimizer):
super().load_state_dict(state_dict)

def __del__(self) -> None:
if getattr(self, 'offloader', None) is not None:
if getattr(self, "offloader", None) is not None:
del self.offloader
if os.path.exists(self.offload_dir):
try: