Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-20 04:32:47 +00:00
[hotfix] run cpu adam unittest in pytest (#424)
commit 5d7dc3525b (parent 54229cd33e)

Summary of the change: the test gains a module-level test_cpu_adam() entry point so that pytest actually collects it, the top-level colossalai import goes away, the bias_correction parameter is dropped from the reference torch_adam_update (together with the literal True argument at its call site), the test tensor shapes shrink so the CPU run stays cheap, and a stray print is removed.
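Why the wrapper is needed (a minimal sketch, not part of the commit itself): pytest collects classes matching its Test* naming pattern only when they define no __init__ constructor. The Test class in this file has an __init__, so pytest skips it with a collection warning and its test_cpu_adam method never runs, whereas a plain module-level test_ function is always collected:

    # minimal sketch of pytest's collection behavior, assuming default settings
    class Test:
        def __init__(self):       # defining __init__ makes pytest skip the class
            self.opt_id = 0

        def test_cpu_adam(self):  # never collected as a method of the class above
            pass

    def test_cpu_adam():          # a module-level test_ function is collected
        Test().test_cpu_adam()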
@@ -29,12 +29,12 @@
 
 import math
 import torch
-import colossalai
 try:
     import cpu_adam
 except ImportError:
     raise ImportError("import cpu_adam error")
 
+
 def torch_adam_update(
     step,
     lr,
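Re-raising inside the except block preserves a hard failure when the extension is missing. A common pytest alternative (hypothetical here, not what this commit does) is to skip the whole module instead:

    import pytest

    # pytest.importorskip skips every test in this module if the import fails
    cpu_adam = pytest.importorskip("cpu_adam")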
@@ -42,7 +42,6 @@ def torch_adam_update(
     beta2,
     eps,
     weight_decay,
-    bias_correction,
     param,
     grad,
     exp_avg,
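The body of torch_adam_update is not shown in this diff. For orientation, a standard bias-corrected Adam step, which the function presumably implements (names and the exact eps placement below are illustrative assumptions, not taken from the commit), looks roughly like:

    import math

    def adam_step(step, lr, beta1, beta2, eps, param, grad, exp_avg, exp_avg_sq):
        # first and second moment running averages (torch tensors, updated in place)
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # bias correction, formerly gated by the now-removed bias_correction flag
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1
        param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-step_size)

Dropping the flag means bias correction is presumably applied unconditionally now, which matches the removal of the literal True at the call site in the -132,7 hunk below.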
@@ -73,6 +72,7 @@ def torch_adam_update(
 
+
 class Test():
 
     def __init__(self):
         self.opt_id = 0
 
@@ -89,7 +89,6 @@ class Test():
         eps,
         beta1,
         beta2,
-
         weight_decay,
         shape,
         grad_dtype,
@@ -132,7 +131,6 @@ class Test():
             beta2,
             eps,
             weight_decay,
-            True,
             p_data_copy,    # fp32 data
             p_grad_copy,    # fp32 grad
             exp_avg_copy,
@@ -158,16 +156,14 @@ class Test():
         max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
         self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
         max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-        self.assertTrue(
-            max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}"
-        )
+        self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
 
     def test_cpu_adam(self):
         lr = 0.9
         eps = 1e-6
         weight_decay = 0
         for use_adamw in [False, True]:
-            for shape in [(1023, ), (32, 1024)]:
+            for shape in [(23,), (8, 24)]:
                 for step in range(1, 2):
                     for lr in [0.01]:
                         for eps in [1e-8]:
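Shrinking the shapes from (1023,) and (32, 1024) to (23,) and (8, 24) keeps the CPU run cheap while still covering one 1-D and one 2-D layout. A possible follow-up refactor (hypothetical, not part of this commit) would let pytest expand the nested loops into individually reported cases:

    import pytest

    # hypothetical parametrization mirroring the loops above
    @pytest.mark.parametrize("use_adamw", [False, True])
    @pytest.mark.parametrize("shape", [(23,), (8, 24)])
    def test_cpu_adam_parametrized(use_adamw, shape):
        ...  # body would mirror the innermost loop body above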
@@ -191,7 +187,11 @@ class Test():
                             )
 
 
+def test_cpu_adam():
+    test_case = Test()
+    test_case.test_cpu_adam()
+
+
 if __name__ == "__main__":
     test = Test()
     test.test_cpu_adam()
-    print('All is well.')
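With the wrapper in place, pytest discovers test_cpu_adam by its test_ prefix, while direct execution through the __main__ block keeps working. Assuming the file sits under the repository's tests/ tree (its path is not shown in this diff), a targeted run looks like:

    pytest -v -k test_cpu_adam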