Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 18:19:58 +00:00
[hotfix] fix CPUAdam kernel nullptr (#1410)
@@ -54,7 +54,7 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    weight_decay = 0

    for i in range(1024):
        p_data = torch.rand(64, dtype=p_dtype)
        p_data_copy = p_data.clone().float()
@@ -67,13 +67,11 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
        try:
            import cpu_adam
-           cpu_adam_op = cpu_adam
+           cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
        except:
            raise ImportError("Import cpu adam error, please install colossal from source code")

-       cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
-       cpu_adam_op.adam_update(
-           0,
+       cpu_adam_op.step(
            step,
            lr,
            beta1,
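The hunk above replaces the old id-based calls (create_adam(0, ...) followed by adam_update(0, ...)) with a cpu_adam.CPUAdamOptimizer object whose step method is invoked directly, so the test no longer reaches the kernel through the integer handle that the nullptr hotfix is concerned with. For orientation, the update such a test checks is the standard Adam/AdamW step; a minimal pure-PyTorch sketch of that rule (illustrative only, not the extension's code, with names chosen here) is:

import torch

def reference_adam_step(p, g, m, v, step, lr, beta1, beta2, eps, weight_decay, adamw):
    # One textbook Adam/AdamW update in fp32; illustrative reference only.
    if weight_decay != 0.0:
        if adamw:
            p.mul_(1.0 - lr * weight_decay)               # decoupled weight decay (AdamW)
        else:
            g = g + weight_decay * p                      # classic Adam: L2 term folded into the gradient
    m.mul_(beta1).add_(g, alpha=1.0 - beta1)              # first-moment estimate
    v.mul_(beta2).addcmul_(g, g, value=1.0 - beta2)       # second-moment estimate
    m_hat = m / (1.0 - beta1 ** step)                     # bias correction
    v_hat = v / (1.0 - beta2 ** step)
    p.add_(m_hat / (v_hat.sqrt() + eps), alpha=-lr)       # parameter update

Here p, g, m, v are fp32 tensors of the same shape and step starts at 1, matching the hyperparameters set earlier in the test.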
@@ -8,9 +8,11 @@ from colossalai.testing import parameterize
class FC(nn.Module):
+
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(64, 64))
+
    def forward(self, x):
        return self.fc(x)
@@ -37,7 +39,7 @@ def test_adam(adamw, p_dtype, g_dtype):
    for d, l in zip(data, label):
        y = model(d)
-       loss = ((l - y) ** 2).sum()
+       loss = ((l - y)**2).sum()
        optim.zero_grad()
        loss.backward()
        if p_dtype != g_dtype:
@@ -47,13 +49,13 @@ def test_adam(adamw, p_dtype, g_dtype):
    for d, l in zip(data_copy, label):
        y = model_copy(d)
-       loss = ((l - y) ** 2).sum()
+       loss = ((l - y)**2).sum()
        torch_optim.zero_grad()
        loss.backward()
        torch_optim.step()

    assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])

    for i in range(len(optim.param_groups[0]['params'])):
        if torch.isnan(optim.param_groups[0]['params'][i]).any() \
                or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
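Both loops above follow the same pattern: the model under test is trained with the Colossal-AI optimizer while an fp32 copy is trained with the corresponding torch.optim optimizer, and the resulting parameters are compared group by group, skipping entries that went NaN in low precision. A self-contained sketch of that comparison pattern, using plain torch.optim on both sides so it runs without the compiled kernels (all names here are illustrative), looks like:

import copy

import torch
import torch.nn as nn


def compare_optimizers(make_optim_a, make_optim_b, steps=10):
    # Train two identical copies of a tiny model and compare their parameters.
    torch.manual_seed(0)
    model_a = nn.Sequential(nn.Linear(64, 64))
    model_b = copy.deepcopy(model_a)
    optim_a = make_optim_a(model_a.parameters())
    optim_b = make_optim_b(model_b.parameters())

    data = torch.rand(steps, 64)
    label = torch.rand(steps, 64)
    for d, l in zip(data, label):
        for model, optim in ((model_a, optim_a), (model_b, optim_b)):
            y = model(d)
            loss = ((l - y)**2).sum()
            optim.zero_grad()
            loss.backward()
            optim.step()

    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        if torch.isnan(p_a).any() or torch.isnan(p_b).any():
            continue    # the real test tolerates occasional NaNs in low precision
        assert torch.allclose(p_a, p_b, rtol=1e-5, atol=1e-5)

# Sanity check of the harness itself: Adam against Adam.
compare_optimizers(lambda ps: torch.optim.Adam(ps, lr=1e-3),
                   lambda ps: torch.optim.Adam(ps, lr=1e-3))

In the real test, one of the two optimizers is the Colossal-AI Adam variant, the parameters of the model under test may be fp16 or bf16, and the gradients are cast when p_dtype != g_dtype, as the hunk above shows.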
@@ -7,6 +7,7 @@ import math
from colossalai.testing import parameterize
from colossalai.utils import multi_tensor_applier

+
def torch_adam_update(
    step,
    lr,
@@ -51,7 +52,7 @@ def test_adam(adamw, step, p_dtype, g_dtype):
        dummy_overflow_buf = torch.cuda.IntTensor([0])
    except:
        raise ImportError("No colossal_C kernel installed.")

    count = 0

    for i in range(1024):
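The fused path only runs when the compiled colossal_C extension is importable; the hunk above shows the CUDA overflow buffer being created inside that guard before the ImportError branch. A minimal sketch of the guard is below; the attribute name used for the fused kernel is an assumption for illustration, not taken from this diff.

import torch

try:
    import colossal_C                               # compiled extension, built from source
    fused_adam = colossal_C.multi_tensor_adam       # assumed attribute name, for illustration
    dummy_overflow_buf = torch.cuda.IntTensor([0])  # flag tensor used by the multi-tensor kernels
except Exception as exc:
    raise ImportError("No colossal_C kernel installed.") from exc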
@@ -69,26 +70,26 @@ def test_adam(adamw, step, p_dtype, g_dtype):
        eps = 1e-8
        weight_decay = 0

-       multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]],
-                            lr, beta1, beta2, eps, step, adamw,
-                            True, weight_decay)
+       multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
+                            True, weight_decay)

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_copy,    # fp32 data
            g_copy,    # fp32 grad
            m_copy,
            v_copy,
            adamw,
        )

        if torch.isnan(p).any() or torch.isnan(p_copy).any():
            count += 1
            continue
        assert count < 200, "too many nans"
-       assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
+       assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5,
+                             1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
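The reformatted assert above passes 1e-5, 1e-5 positionally to torch.allclose; those are rtol and atol, so the check is equivalent to the keyword form below, written out only to make the tolerances explicit:

torch.allclose(p.to(torch.float), p_copy.to(torch.float), rtol=1e-5, atol=1e-5)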
@@ -17,7 +17,7 @@ def test_adam(adamw, device, p_dtype, g_dtype):
    rng_state = torch.get_rng_state()
    p = nn.Parameter(torch.rand(64).to(device, p_dtype))
    torch.set_rng_state(rng_state)
    p_copy = nn.Parameter(torch.rand(64).to(device).float())

    if adamw:
        optim = HybridAdam([p], lr=1e-3, adamw_mode=True)
@@ -38,4 +38,4 @@ def test_adam(adamw, device, p_dtype, g_dtype):
        if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any():
            continue
        assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \
            f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"
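test_adam in the hybrid test builds the optimizer directly over a single 64-element parameter and compares it against the matching torch.optim optimizer under looser tolerances (rtol 1e-4, atol 1e-2), again skipping NaN results. A minimal end-to-end sketch of that usage is below; the import path for HybridAdam is an assumption for illustration, and it runs only where the Colossal-AI kernels are built.

import torch
import torch.nn as nn

from colossalai.nn.optimizer import HybridAdam    # import path assumed for illustration

p = nn.Parameter(torch.rand(64))
p_copy = nn.Parameter(p.data.clone())

optim = HybridAdam([p], lr=1e-3, adamw_mode=True)                      # constructor as shown in the hunk
torch_optim = torch.optim.AdamW([p_copy], lr=1e-3, weight_decay=0.0)   # weight decay off, as in the tests above

for _ in range(10):
    g = torch.rand(64)
    p.grad = g.clone()       # feed both optimizers the same gradient
    p_copy.grad = g.clone()
    optim.step()
    torch_optim.step()

if not (torch.isnan(p.data).any() or torch.isnan(p_copy.data).any()):
    assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2)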