diff --git a/tests/test_optimizer/unittest_cpu_adam.py b/tests/test_optimizer/unittest_cpu_adam.py
index 2f1e62174..401fc5241 100644
--- a/tests/test_optimizer/unittest_cpu_adam.py
+++ b/tests/test_optimizer/unittest_cpu_adam.py
@@ -29,12 +29,12 @@ import math
 
 import torch
-import colossalai
 
 try:
     import cpu_adam
 except ImportError:
     raise ImportError("import cpu_adam error")
 
+
 def torch_adam_update(
     step,
     lr,
@@ -42,7 +42,6 @@ def torch_adam_update(
     beta2,
     eps,
     weight_decay,
-    bias_correction,
     param,
     grad,
     exp_avg,
@@ -52,8 +51,8 @@
 ):
     if loss_scale > 0:
         grad.div_(loss_scale)
-    bias_correction1 = 1 - beta1 ** step
-    bias_correction2 = 1 - beta2 ** step
+    bias_correction1 = 1 - beta1**step
+    bias_correction2 = 1 - beta2**step
 
     if weight_decay != 0:
         if use_adamw:
@@ -73,12 +72,13 @@
 
 
 class Test():
+
     def __init__(self):
         self.opt_id = 0
-    
+
     def assertLess(self, data_diff, threshold, msg):
         assert data_diff < threshold, msg
-    
+
     def assertTrue(self, condition, msg):
         assert condition, msg
 
@@ -89,7 +89,6 @@ class Test():
         eps,
         beta1,
         beta2,
-        weight_decay,
         shape,
         grad_dtype,
@@ -118,8 +117,8 @@ class Test():
             eps,
             weight_decay,
             True,
-            p_data.view(-1), # fp32 data
-            p_grad.view(-1), # fp32 grad
+            p_data.view(-1),    # fp32 data
+            p_grad.view(-1),    # fp32 grad
             exp_avg.view(-1),
             exp_avg_sq.view(-1),
             loss_scale,
@@ -132,15 +131,14 @@ class Test():
             beta2,
             eps,
             weight_decay,
-            True,
-            p_data_copy, # fp32 data
-            p_grad_copy, # fp32 grad
+            p_data_copy,    # fp32 data
+            p_grad_copy,    # fp32 grad
             exp_avg_copy,
             exp_avg_sq_copy,
             loss_scale,
             use_adamw,
         )
-        
+
         if loss_scale > 0:
             p_grad.div_(loss_scale)
 
@@ -158,16 +156,14 @@ class Test():
         max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
         self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
         max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-        self.assertTrue(
-            max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}"
-        )
+        self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
 
     def test_cpu_adam(self):
         lr = 0.9
         eps = 1e-6
         weight_decay = 0
         for use_adamw in [False, True]:
-            for shape in [(1023, ), (32, 1024)]:
+            for shape in [(23,), (8, 24)]:
                 for step in range(1, 2):
                     for lr in [0.01]:
                         for eps in [1e-8]:
@@ -175,7 +171,7 @@ class Test():
                                 for beta2 in [0.999]:
                                     for weight_decay in [0.001]:
                                         for grad_dtype in [torch.half, torch.float]:
-                                            for loss_scale in [-1, 2 ** 5]:
+                                            for loss_scale in [-1, 2**5]:
                                                 self.check_res(
                                                     step,
                                                     lr,
@@ -191,7 +187,11 @@ class Test():
                                                 )
 
 
+def test_cpu_adam():
+    test_case = Test()
+    test_case.test_cpu_adam()
+
+
 if __name__ == "__main__":
     test = Test()
     test.test_cpu_adam()
-    print('All is well.')
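
For reference, the torch-side baseline that this test compares `cpu_adam` against is the standard bias-corrected Adam/AdamW step. Below is a minimal sketch of that math, assuming the usual PyTorch formulation; the function name `adam_step_sketch` and its plain-tensor signature are illustrative only and are not part of this patch (the patch's `torch_adam_update` additionally divides the gradient by `loss_scale` when it is positive, as shown in the hunk at old line 52).

import math

import torch


def adam_step_sketch(step, lr, beta1, beta2, eps, weight_decay,
                     param, grad, exp_avg, exp_avg_sq, use_adamw=False):
    # Bias-correction terms for the first and second moment estimates.
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    if weight_decay != 0:
        if use_adamw:
            # AdamW: decoupled weight decay applied directly to the parameters.
            param.mul_(1 - lr * weight_decay)
        else:
            # Classic Adam: weight decay folded into the gradient as an L2 term.
            grad = grad.add(param, alpha=weight_decay)

    # Exponential moving averages of the gradient and the squared gradient.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # param -= (lr / bias_correction1) * exp_avg / (sqrt(exp_avg_sq / bias_correction2) + eps)
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)

In the test, `check_res` runs the fused `cpu_adam` kernel and the reference update on copies of the same tensors, then asserts that the differences in the parameters, `exp_avg`, and `exp_avg_sq` stay below a threshold.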