[cuda] modify the fused adam, support hybrid of fp16 and fp32 (#497)

This commit is contained in:
LuGY
2022-03-25 14:15:53 +08:00
committed by GitHub
parent 920c5889a7
commit 6a3f9fda83
6 changed files with 253 additions and 143 deletions

View File

@@ -1,38 +1,7 @@
# BSD 3-Clause License
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the psutil authors nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import math
import torch
try:
import cpu_adam
except ImportError:
raise ImportError("import cpu_adam error")
from colossalai.testing import parameterize
def torch_adam_update(
@@ -71,45 +40,46 @@ def torch_adam_update(
param.addcdiv_(exp_avg, denom, value=-step_size)
class Test():
def assertLess(data_diff, threshold, msg):
assert data_diff < threshold, msg
def __init__(self):
self.opt_id = 0
def assertLess(self, data_diff, threshold, msg):
assert data_diff < threshold, msg
def assertTrue(condition, msg):
assert condition, msg
def assertTrue(self, condition, msg):
assert condition, msg
def check_res(
self,
step,
lr,
eps,
beta1,
beta2,
weight_decay,
shape,
grad_dtype,
loss_scale,
use_adamw,
cpu_adam_op,
):
p_data = torch.rand(shape, dtype=grad_dtype)
@parameterize('adamw', [True, False])
@parameterize('step', [1, 2])
@parameterize('loss_scale', [-1, 2 ** 5])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8
weight_decay = 0
for i in range(1024):
p_data = torch.rand(64, dtype=p_dtype)
p_data_copy = p_data.clone().float()
p_grad = torch.rand(shape, dtype=grad_dtype)
p_grad = torch.rand(64, dtype=g_dtype)
if loss_scale > 0:
p_grad.mul_(loss_scale)
p_grad_copy = p_grad.clone().float()
exp_avg = torch.rand(shape)
exp_avg = torch.rand(p_data.shape)
exp_avg_copy = exp_avg.clone()
exp_avg_sq = torch.rand(shape)
exp_avg_sq = torch.rand(p_data.shape)
exp_avg_sq_copy = exp_avg_sq.clone()
cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, use_adamw, True)
try:
import cpu_adam
cpu_adam_op = cpu_adam
except:
raise ImportError("...")
cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
cpu_adam_op.adam_update(
self.opt_id,
0,
step,
lr,
beta1,
@@ -136,62 +106,24 @@ class Test():
exp_avg_copy,
exp_avg_sq_copy,
loss_scale,
use_adamw,
adamw,
)
if loss_scale > 0:
p_grad.div_(loss_scale)
var = p_data_copy - p_data
data_diff = torch.max(torch.abs(var))
threshold = 2e-3 if grad_dtype else 1e-4
self.assertLess(
threshold = 1e-3
print(f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}")
assertLess(
data_diff,
threshold,
f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} loss_scale {loss_scale} grad_dtype {grad_dtype}",
f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, loss_scale {loss_scale}, eps "
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
)
max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
self.assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
def test_cpu_adam(self):
lr = 0.9
eps = 1e-6
weight_decay = 0
for use_adamw in [False, True]:
for shape in [(23,), (8, 24)]:
for step in range(1, 2):
for lr in [0.01]:
for eps in [1e-8]:
for beta1 in [0.9]:
for beta2 in [0.999]:
for weight_decay in [0.001]:
for grad_dtype in [torch.half, torch.float]:
for loss_scale in [-1, 2**5]:
self.check_res(
step,
lr,
eps,
beta1,
beta2,
weight_decay,
shape,
grad_dtype,
loss_scale,
use_adamw,
cpu_adam,
)
def test_cpu_adam():
test_case = Test()
test_case.test_cpu_adam()
if __name__ == "__main__":
test = Test()
test.test_cpu_adam()
assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")