mirror of
https://github.com/hpcaitech/ColossalAI.git
[cuda] modify the fused adam, support hybrid of fp16 and fp32 (#497)
@@ -22,7 +22,7 @@ typedef enum
 using MATH_T = float;

-template <typename T>
+template <typename T_g, typename T_p>
 struct AdamFunctor
 {
   __device__ __forceinline__ void operator()(
@@ -50,16 +50,16 @@ struct AdamFunctor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];

-    T *g = (T *)tl.addresses[0][tensor_loc];
+    T_g *g = (T_g *)tl.addresses[0][tensor_loc];
     g += chunk_idx * chunk_size;

-    T *p = (T *)tl.addresses[1][tensor_loc];
+    T_p *p = (T_p *)tl.addresses[1][tensor_loc];
     p += chunk_idx * chunk_size;

-    T *m = (T *)tl.addresses[2][tensor_loc];
+    T_p *m = (T_p *)tl.addresses[2][tensor_loc];
     m += chunk_idx * chunk_size;

-    T *v = (T *)tl.addresses[3][tensor_loc];
+    T_p *v = (T_p *)tl.addresses[3][tensor_loc];
     v += chunk_idx * chunk_size;

     n -= chunk_idx * chunk_size;
@@ -155,15 +155,15 @@ void multi_tensor_adam_cuda(
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
-      tensor_lists[0][0].scalar_type(), 0, "adam",
+  DISPATCH_FLOAT_AND_HALF_FOR_G_P(
+      tensor_lists[0][0].scalar_type(),
+      tensor_lists[1][0].scalar_type(), 0, "adam",
       multi_tensor_apply<4>(
           BLOCK_SIZE,
           chunk_size,
           noop_flag,
           tensor_lists,
-          AdamFunctor<scalar_t_0>(),
+          AdamFunctor<g_scalar_t_0, p_scalar_t_0>(),
           beta1,
           beta2,
           bias_correction1,
@@ -173,6 +173,36 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }

+#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
+  if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) \
+  { \
+    using g_scalar_t_##LEVEL = float; \
+    using p_scalar_t_##LEVEL = float; \
+    __VA_ARGS__; \
+  } \
+  else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \
+  { \
+    using g_scalar_t_##LEVEL = float; \
+    using p_scalar_t_##LEVEL = at::Half; \
+    __VA_ARGS__; \
+  } \
+  else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \
+  { \
+    using g_scalar_t_##LEVEL = at::Half; \
+    using p_scalar_t_##LEVEL = float; \
+    __VA_ARGS__; \
+  } \
+  else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) \
+  { \
+    using g_scalar_t_##LEVEL = at::Half; \
+    using p_scalar_t_##LEVEL = at::Half; \
+    __VA_ARGS__; \
+  } \
+  else \
+  { \
+    AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'"); \
+  } \
+
 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes(T *x,
                                                      T val,
@@ -10,7 +10,7 @@ class FusedAdam(torch.optim.Optimizer):
     """Implements Adam algorithm.

     Currently GPU-only. Requires ColossalAI to be installed via
-    ``pip install -v --no-cache-dir --global-option="--cuda_ext" ./``.
+    ``pip install .``.

     This version of fused Adam implements 2 fusions.
@@ -18,7 +18,7 @@ class FusedAdam(torch.optim.Optimizer):
     * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.

     :class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
-    or ``torch.optim.Adam`` with ``adam_w_mode=False``
+    or ``torch.optim.Adam`` with ``adamw_mode=False``

     :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.
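A minimal sketch of the drop-in behavior described in the docstring above (hypothetical model and hyperparameters; assumes ColossalAI and its CUDA extension are installed):

import torch
from colossalai.nn.optimizer import FusedAdam

model = torch.nn.Linear(128, 128).cuda()

# decoupled weight decay, comparable to torch.optim.AdamW (the default)
adamw_like = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True, weight_decay=1e-2)

# L2-regularization style, comparable to torch.optim.Adam
adam_like = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=False, weight_decay=1e-2)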
@@ -36,7 +36,7 @@ class FusedAdam(torch.optim.Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False) NOT SUPPORTED in FusedAdam!
-        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
+        adamw_mode (boolean, optional): Apply L2 regularization or weight decay
             True for decoupled weight decay(also known as AdamW) (default: True)
         set_grad_none (bool, optional): whether set grad to None when zero_grad()
             method is called. (default: True)
@@ -53,7 +53,7 @@ class FusedAdam(torch.optim.Optimizer):
                  bias_correction=True,
                  betas=(0.9, 0.999),
                  eps=1e-8,
-                 adam_w_mode=True,
+                 adamw_mode=True,
                  weight_decay=0.,
                  amsgrad=False,
                  set_grad_none=True):
@@ -62,7 +62,7 @@ class FusedAdam(torch.optim.Optimizer):
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
         super(FusedAdam, self).__init__(params, defaults)
-        self.adam_w_mode = 1 if adam_w_mode else 0
+        self.adamw_mode = 1 if adamw_mode else 0
         self.set_grad_none = set_grad_none
         if multi_tensor_applier.available:
             import colossal_C
@@ -109,8 +109,7 @@ class FusedAdam(torch.optim.Optimizer):
                 group['step'] = 1

             # create lists for multi-tensor apply
-            g_16, p_16, m_16, v_16 = [], [], [], []
-            g_32, p_32, m_32, v_32 = [], [], [], []
+            g_l, p_l, m_l, v_l = [], [], [], []

             for p in group['params']:
                 if p.grad is None:
@@ -127,26 +126,16 @@ class FusedAdam(torch.optim.Optimizer):
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)

-                if p.dtype == torch.float16:
-                    g_16.append(p.grad.data)
-                    p_16.append(p.data)
-                    m_16.append(state['exp_avg'])
-                    v_16.append(state['exp_avg_sq'])
-                elif p.dtype == torch.float32:
-                    g_32.append(p.grad.data)
-                    p_32.append(p.data)
-                    m_32.append(state['exp_avg'])
-                    v_32.append(state['exp_avg_sq'])
-                else:
+                if p.dtype not in [torch.float16, torch.float32]:
                     raise RuntimeError('FusedAdam only support fp16 and fp32.')

+                g_l.append(p.grad.data)
+                p_l.append(p.data)
+                m_l.append(state['exp_avg'])
+                v_l.append(state['exp_avg_sq'])

-            if (len(g_16) > 0):
-                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
-                                     group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
-                                     bias_correction, group['weight_decay'])
-            if (len(g_32) > 0):
-                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
-                                     group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
-                                     bias_correction, group['weight_decay'])
+            multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l],
+                                 group['lr'], beta1, beta2, group['eps'], group['step'], self.adamw_mode,
+                                 bias_correction, group['weight_decay'])

         return loss
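Taken together, the Python side now funnels all parameters through a single multi-tensor launch, and the CUDA side dispatches on the gradient and parameter dtypes independently, so fp16 gradients can update fp32 master weights. A minimal end-to-end sketch of ordinary usage (hypothetical model and hyperparameters; assumes a CUDA build of ColossalAI with the colossal_C extension):

import torch
from colossalai.nn.optimizer import FusedAdam

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8)).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.)

x = torch.randn(32, 64, device='cuda')
loss = model(x).sum()
loss.backward()
optimizer.step()       # one fused multi-tensor Adam launch covers every parameter
optimizer.zero_grad()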