mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-14 11:18:58 +00:00
[cuda] modify the fused adam, support hybrid of fp16 and fp32 (#497)
This commit is contained in:
98
tests/test_optimizer/unittest_fused_adam_kernel.py
Normal file
98
tests/test_optimizer/unittest_fused_adam_kernel.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from numpy import dtype
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import math
|
||||
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.utils import multi_tensor_applier
|
||||
|
||||
def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    loss_scale,
    use_adamw,
):
    """Apply one Adam / AdamW update in place using plain PyTorch tensor ops.

    ``param``, ``exp_avg`` and ``exp_avg_sq`` are always modified in place.
    ``grad`` is modified in place only when ``loss_scale`` is positive; the
    non-AdamW weight-decay path works on a new tensor and leaves the caller's
    ``grad`` untouched.

    Args:
        step: Step count used as the exponent of the bias corrections. A
            value of 0 makes both corrections zero and the division fail.
        lr: Learning rate.
        beta1: Decay factor of the first-moment running average.
        beta2: Decay factor of the second-moment running average.
        eps: Constant added to the denominator.
        weight_decay: Weight-decay coefficient; 0 disables decay entirely.
        param: Parameter tensor, updated in place.
        grad: Gradient tensor.
        exp_avg: First-moment running average, updated in place.
        exp_avg_sq: Second-moment running average, updated in place.
        loss_scale: When greater than 0, ``grad`` is divided by it in place
            before anything else; any other value skips the unscaling.
        use_adamw: Truthy to shrink ``param`` directly (decoupled decay);
            falsy to add ``weight_decay * param`` to the gradient instead.

    Returns:
        None. Results are delivered through the in-place tensor updates.
    """
    # Unscale the gradient first so every later quantity sees the true value.
    if loss_scale > 0:
        grad.div_(loss_scale)

    first_correction = 1 - beta1**step
    second_correction = 1 - beta2**step

    decay_enabled = weight_decay != 0
    if decay_enabled and use_adamw:
        # Decoupled decay: scale the parameter itself.
        param.mul_(1 - lr * weight_decay)
    elif decay_enabled:
        # Gradient-side decay: rebinds the local name to a new tensor, so the
        # caller's gradient is not altered by this branch.
        grad = grad.add(param, alpha=weight_decay)

    # Refresh the running averages of the gradient and of its square.
    exp_avg.mul_(beta1)
    exp_avg.add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2)
    exp_avg_sq.addcmul_(grad, grad, value=1 - beta2)

    # Bias-corrected root of the second moment, offset by eps.
    denom = exp_avg_sq.sqrt() / math.sqrt(second_correction)
    denom.add_(eps)

    # The first-moment bias correction is folded into the step size.
    param.addcdiv_(exp_avg, denom, value=-(lr / first_correction))
|
||||
|
||||
|
||||
@parameterize('adamw', [False, True])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, step, p_dtype, g_dtype):
    """Check the fused ``colossal_C`` Adam kernel against ``torch_adam_update``.

    Runs 1024 trials. Each trial draws random 64-element parameter, gradient
    and moment tensors on the GPU, applies one fused update to them, applies
    the reference update to float copies, and requires the two parameter
    results to agree within ``rtol=1e-5`` / ``atol=1e-5``. Trials in which
    either result contains NaN are skipped, up to a budget of fewer than 200.

    Args:
        adamw: Forwarded to the fused kernel and, as ``use_adamw``, to the
            reference update.
        step: Optimizer step count forwarded to both implementations.
        p_dtype: dtype of the parameter tensor handed to the fused kernel.
        g_dtype: dtype of the gradient tensor handed to the fused kernel.

    Raises:
        ImportError: If importing ``colossal_C``, looking up its
            ``multi_tensor_adam`` entry point, or allocating the CUDA
            overflow buffer fails.
        AssertionError: If too many trials produce NaN or the two results
            diverge.
    """
    try:
        import colossal_C
        fused_adam = colossal_C.multi_tensor_adam
        dummy_overflow_buf = torch.cuda.IntTensor([0])
    except Exception as err:
        # Keep the historical ImportError, but no longer trap
        # KeyboardInterrupt/SystemExit, and chain the real cause explicitly.
        raise ImportError("No colossal_C kernel installed.") from err

    # Hyperparameters do not change between trials, so build them once.
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    # NOTE(review): with weight_decay fixed at 0 the reference never enters
    # its decay branch, so `adamw` cannot change the reference result here.
    weight_decay = 0

    nan_budget = 200
    count = 0

    for _ in range(1024):
        p = torch.rand(64, dtype=p_dtype).cuda()
        p_copy = p.clone().float()
        g = torch.rand(p.shape, dtype=g_dtype).cuda()
        g_copy = g.clone().float()
        # The moment buffers use torch.rand's default dtype, not p_dtype.
        m = torch.rand(p.shape).cuda()
        m_copy = m.clone()
        v = torch.rand(p.shape).cuda()
        v_copy = v.clone()

        # NOTE(review): the literal True is presumably the kernel's
        # bias-correction switch -- confirm against the colossal_C signature.
        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]],
                             lr, beta1, beta2, eps, step, adamw,
                             True, weight_decay)

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_copy,    # fp32 data
            g_copy,    # fp32 grad
            m_copy,
            v_copy,
            -1,    # non-positive loss scale: the reference skips unscaling
            adamw,
        )

        if torch.isnan(p).any() or torch.isnan(p_copy).any():
            count += 1
            # Enforce the budget here: a check placed after `continue` is
            # skipped on exactly the trials that increase the count, so a run
            # made only of NaN trials would pass without comparing anything.
            assert count < nan_budget, "too many nans"
            continue

        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
|
Reference in New Issue
Block a user