Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-20 12:43:55 +00:00)
[hotfix] run cpu adam unittest in pytest (#424)

parent 54229cd33e
commit 5d7dc3525b
@@ -29,12 +29,12 @@
 import math
 import torch
-import colossalai
+
 try:
     import cpu_adam
 except ImportError:
     raise ImportError("import cpu_adam error")
 
 
 def torch_adam_update(
     step,
     lr,
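
Since the point of this hotfix is running the test under pytest, note that the hard raise ImportError above aborts collection when the extension is absent. A minimal alternative sketch, assuming only standard pytest API (pytest.importorskip), would skip the module instead of erroring:

import pytest

# Skip this whole test module when the compiled extension is unavailable,
# instead of failing at import time; "cpu_adam" is the extension under test.
cpu_adam = pytest.importorskip("cpu_adam")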
@@ -42,7 +42,6 @@ def torch_adam_update(
     beta2,
     eps,
     weight_decay,
-    bias_correction,
     param,
     grad,
     exp_avg,
@@ -52,8 +51,8 @@ def torch_adam_update(
 ):
     if loss_scale > 0:
         grad.div_(loss_scale)
-    bias_correction1 = 1 - beta1 ** step
-    bias_correction2 = 1 - beta2 ** step
+    bias_correction1 = 1 - beta1**step
+    bias_correction2 = 1 - beta2**step
 
     if weight_decay != 0:
         if use_adamw:
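
The two reformatted lines are the standard Adam bias correction (Kingma & Ba, 2015): the moment buffers start at zero, so early averages are scaled down by 1 - beta**step. Dropping the bias_correction parameter in the previous hunk (and its True argument in a later hunk) makes the correction unconditional. A minimal sketch of the dense update these terms feed, assuming the usual Adam recurrences; names mirror the test, but adam_step_sketch itself is illustrative, not this file's actual body:

import torch


def adam_step_sketch(step, lr, beta1, beta2, eps, param, grad, exp_avg, exp_avg_sq):
    # Moment estimates start at zero, so early steps are biased toward zero;
    # dividing by (1 - beta**step) undoes that bias.
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)               # m_t
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t
    denom = (exp_avg_sq / bias_correction2).sqrt_().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)


# Tiny usage example:
p, g = torch.zeros(4), torch.ones(4)
adam_step_sketch(1, 1e-3, 0.9, 0.999, 1e-8, p, g, torch.zeros(4), torch.zeros(4))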
@@ -73,6 +72,7 @@ def torch_adam_update(
 
+
 class Test():
 
     def __init__(self):
         self.opt_id = 0
 
@@ -89,7 +89,6 @@ class Test():
         eps,
         beta1,
         beta2,
-
         weight_decay,
         shape,
         grad_dtype,
@@ -118,8 +117,8 @@ class Test():
             eps,
             weight_decay,
             True,
             p_data.view(-1),  # fp32 data
             p_grad.view(-1),  # fp32 grad
             exp_avg.view(-1),
             exp_avg_sq.view(-1),
             loss_scale,
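
The .view(-1) calls hand the extension flat fp32 buffers; a C++ kernel such as cpu_adam's typically expects contiguous 1-D tensors (an assumption about the kernel's contract, not stated in this file). A small illustration that flattening is a zero-copy view over the same storage:

import torch

p_data = torch.rand(8, 24)
flat = p_data.view(-1)                       # shape (192,), no copy made
assert flat.data_ptr() == p_data.data_ptr()  # same underlying storage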
@@ -132,9 +131,8 @@ class Test():
             beta2,
             eps,
             weight_decay,
-            True,
             p_data_copy,  # fp32 data
             p_grad_copy,  # fp32 grad
             exp_avg_copy,
             exp_avg_sq_copy,
             loss_scale,
@@ -158,16 +156,14 @@ class Test():
         max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
         self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
         max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-        self.assertTrue(
-            max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}"
-        )
+        self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
 
     def test_cpu_adam(self):
         lr = 0.9
         eps = 1e-6
         weight_decay = 0
         for use_adamw in [False, True]:
-            for shape in [(1023, ), (32, 1024)]:
+            for shape in [(23,), (8, 24)]:
                 for step in range(1, 2):
                     for lr in [0.01]:
                         for eps in [1e-8]:
@@ -175,7 +171,7 @@ class Test():
                                 for beta2 in [0.999]:
                                     for weight_decay in [0.001]:
                                         for grad_dtype in [torch.half, torch.float]:
-                                            for loss_scale in [-1, 2 ** 5]:
+                                            for loss_scale in [-1, 2**5]:
                                                 self.check_res(
                                                     step,
                                                     lr,
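
The hyperparameter grid is swept with nested for loops inside one test; the commit also shrinks the tensor shapes from (1023,) and (32, 1024) to (23,) and (8, 24), which keeps the pytest run fast. An equivalent pytest idiom, shown only as a sketch (test_cpu_adam_grid is hypothetical, and its body stands in for the real check_res call), would flatten the grid so each combination reports as its own test item:

import pytest
import torch


@pytest.mark.parametrize("loss_scale", [-1, 2**5])
@pytest.mark.parametrize("grad_dtype", [torch.half, torch.float])
@pytest.mark.parametrize("use_adamw", [False, True])
@pytest.mark.parametrize("shape", [(23,), (8, 24)])
def test_cpu_adam_grid(shape, use_adamw, grad_dtype, loss_scale):
    # Placeholder body: the real test would call check_res(...) here with
    # the sampled hyperparameters.
    grad = torch.randn(shape, dtype=grad_dtype)
    assert grad.shape == torch.Size(shape)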
@@ -191,7 +187,11 @@ class Test():
                                                 )
 
 
+def test_cpu_adam():
+    test_case = Test()
+    test_case.test_cpu_adam()
+
+
 if __name__ == "__main__":
     test = Test()
     test.test_cpu_adam()
-    print('All is well.')
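
The module-level test_cpu_adam() wrapper is the substance of the hotfix: pytest collects top-level test_* functions, but it refuses to collect a test class that defines __init__, as Test does to set self.opt_id. A minimal sketch of that collection behavior, with a trivial body standing in for the real checks:

class Test:                        # skipped by pytest: defines __init__

    def __init__(self):
        self.opt_id = 0

    def test_cpu_adam(self):
        assert True                # stands in for the real assertions


def test_cpu_adam():               # collected: top-level test_* function
    Test().test_cpu_adam()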