[hotfix] run cpu adam unittest in pytest (#424)

This commit is contained in:
Jiarui Fang 2022-03-16 10:39:55 +08:00 committed by GitHub
parent 54229cd33e
commit 5d7dc3525b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,12 +29,12 @@
import math import math
import torch import torch
import colossalai
try: try:
import cpu_adam import cpu_adam
except ImportError: except ImportError:
raise ImportError("import cpu_adam error") raise ImportError("import cpu_adam error")
def torch_adam_update( def torch_adam_update(
step, step,
lr, lr,
@ -42,7 +42,6 @@ def torch_adam_update(
beta2, beta2,
eps, eps,
weight_decay, weight_decay,
bias_correction,
param, param,
grad, grad,
exp_avg, exp_avg,
@ -73,6 +72,7 @@ def torch_adam_update(
class Test(): class Test():
def __init__(self): def __init__(self):
self.opt_id = 0 self.opt_id = 0
@ -89,7 +89,6 @@ class Test():
eps, eps,
beta1, beta1,
beta2, beta2,
weight_decay, weight_decay,
shape, shape,
grad_dtype, grad_dtype,
@ -132,7 +131,6 @@ class Test():
beta2, beta2,
eps, eps,
weight_decay, weight_decay,
True,
p_data_copy, # fp32 data p_data_copy, # fp32 data
p_grad_copy, # fp32 grad p_grad_copy, # fp32 grad
exp_avg_copy, exp_avg_copy,
@ -158,16 +156,14 @@ class Test():
max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg)) max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}") self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq)) max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
self.assertTrue( self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}"
)
def test_cpu_adam(self): def test_cpu_adam(self):
lr = 0.9 lr = 0.9
eps = 1e-6 eps = 1e-6
weight_decay = 0 weight_decay = 0
for use_adamw in [False, True]: for use_adamw in [False, True]:
for shape in [(1023, ), (32, 1024)]: for shape in [(23,), (8, 24)]:
for step in range(1, 2): for step in range(1, 2):
for lr in [0.01]: for lr in [0.01]:
for eps in [1e-8]: for eps in [1e-8]:
@ -191,7 +187,11 @@ class Test():
) )
def test_cpu_adam():
test_case = Test()
test_case.test_cpu_adam()
if __name__ == "__main__": if __name__ == "__main__":
test = Test() test = Test()
test.test_cpu_adam() test.test_cpu_adam()
print('All is well.')