Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -10,16 +10,25 @@ from torch import Tensor
 from colossalai.utils import get_current_device, multi_tensor_applier

-_FUSED_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
-                            (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16),
-                            (torch.bfloat16, torch.bfloat16)]
+_FUSED_ALLOWED_P_G_TYPES = [
+    (torch.float, torch.half),
+    (torch.float, torch.float),
+    (torch.half, torch.float),
+    (torch.half, torch.half),
+    (torch.bfloat16, torch.float),
+    (torch.float, torch.bfloat16),
+    (torch.bfloat16, torch.bfloat16),
+]

-_CPU_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
-                          (torch.half, torch.half)]
+_CPU_ALLOWED_P_G_TYPES = [
+    (torch.float, torch.half),
+    (torch.float, torch.float),
+    (torch.half, torch.float),
+    (torch.half, torch.half),
+]


 class AdamKernel:
-
     def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
         self.lr = lr
         self.beta1 = beta1
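Each tuple in these lists pairs a parameter dtype with a gradient dtype that the corresponding kernel is expected to accept. A minimal sketch of how one such pair is typically materialized in a mixed-precision test (illustrative names, not helpers from this file):

    import torch

    p_dtype, g_dtype = torch.float, torch.half         # one entry of _FUSED_ALLOWED_P_G_TYPES
    master_param = torch.rand(64, dtype=torch.float)   # full-precision master copy
    param = master_param.clone().to(p_dtype)           # working parameter in p_dtype
    grad = torch.rand(64, dtype=g_dtype)               # gradient produced in g_dtype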
@@ -34,7 +43,6 @@ class AdamKernel:


 class TorchAdamKernel(AdamKernel):
-
     def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
         bias_correction1 = 1 - self.beta1**step
         bias_correction2 = 1 - self.beta2**step
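The rest of TorchAdamKernel.update is elided by the hunk. For orientation, a generic Adam/AdamW step using the same bias corrections looks roughly like the following; this is a reference sketch, not the repository's exact implementation:

    import math
    import torch
    from torch import Tensor

    def reference_adam_step(step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor,
                            lr: float, beta1: float, beta2: float, eps: float,
                            weight_decay: float, use_adamw: bool) -> None:
        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step
        if weight_decay != 0:
            if use_adamw:
                param.mul_(1 - lr * weight_decay)               # decoupled (AdamW-style) weight decay
            else:
                grad = grad.add(param, alpha=weight_decay)      # classic Adam: L2 term folded into the gradient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)                 # first moment estimate
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)    # second moment estimate
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)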
@@ -57,36 +65,68 @@ class TorchAdamKernel(AdamKernel):


 class FusedAdamKernel(AdamKernel):
     def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
         super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
         from colossalai.kernel.op_builder import FusedOptimBuilder

         fused_optim = FusedOptimBuilder().load()
         self.fused_adam = fused_optim.multi_tensor_adam
         self.dummy_overflow_buf = torch.cuda.IntTensor([0])

     def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
-        multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, [[grad], [param], [exp_avg], [exp_avg_sq]],
-                             self.lr, self.beta1, self.beta2, self.eps, step, self.use_adamw, True, self.weight_decay,
-                             -1)
+        multi_tensor_applier(
+            self.fused_adam,
+            self.dummy_overflow_buf,
+            [[grad], [param], [exp_avg], [exp_avg_sq]],
+            self.lr,
+            self.beta1,
+            self.beta2,
+            self.eps,
+            step,
+            self.use_adamw,
+            True,
+            self.weight_decay,
+            -1,
+        )
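For readers unfamiliar with the apex-style multi-tensor interface, the positional arguments of this call appear to map as follows. The annotations are inferred from the usual fused-Adam convention and are not part of the diff:

    multi_tensor_applier(
        self.fused_adam,                               # compiled CUDA multi-tensor Adam kernel
        self.dummy_overflow_buf,                       # int32 flag tensor the kernel can set on overflow
        [[grad], [param], [exp_avg], [exp_avg_sq]],    # one tensor list per operand
        self.lr, self.beta1, self.beta2, self.eps,     # optimizer hyperparameters
        step,                                          # current step, used for bias correction
        self.use_adamw,                                # Adam vs. AdamW (decoupled weight decay)
        True,                                          # presumably: enable bias correction
        self.weight_decay,
        -1,                                            # presumably: div_scale, disabled here
    )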


 class CPUAdamKernel(AdamKernel):
     def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
         super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
         from colossalai.kernel.op_builder import CPUAdamBuilder

         cpu_optim = CPUAdamBuilder().load()

         self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)

     def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
-        self.cpu_adam_op.step(step, self.lr, self.beta1, self.beta2, self.eps, self.weight_decay, True, param.view(-1),
-                              grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1), -1)
+        self.cpu_adam_op.step(
+            step,
+            self.lr,
+            self.beta1,
+            self.beta2,
+            self.eps,
+            self.weight_decay,
+            True,
+            param.view(-1),
+            grad.view(-1),
+            exp_avg.view(-1),
+            exp_avg_sq.view(-1),
+            -1,
+        )


-def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float, p_dtype: torch.dtype,
-                      g_dtype: torch.dtype, device: torch.device, n_steps: int, rtol: float, atol: float):
+def check_adam_kernel(
+    kernel: Type[AdamKernel],
+    adamw: bool,
+    weight_decay: float,
+    p_dtype: torch.dtype,
+    g_dtype: torch.dtype,
+    device: torch.device,
+    n_steps: int,
+    rtol: float,
+    atol: float,
+):
     lr = 1e-3
     beta1, beta2 = 0.9, 0.999
     eps = 1e-8
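The body of check_adam_kernel is elided between this hunk and the next; judging from the signature and the allclose assertion below, it runs the kernel under test alongside a full-precision Torch reference for n_steps and compares the parameters. A rough sketch of that pattern, purely illustrative and not the file's actual code (it reuses the kernel classes above):

    import torch

    def check_adam_kernel_sketch(kernel, adamw, weight_decay, p_dtype, g_dtype, device, n_steps, rtol, atol):
        lr, (beta1, beta2), eps = 1e-3, (0.9, 0.999), 1e-8
        reference = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw)   # fp32 reference
        tested = kernel(lr, beta1, beta2, eps, weight_decay, adamw)               # kernel under test

        master_p = torch.rand(64, device=device)                  # fp32 master parameter
        p = master_p.clone().to(p_dtype)                          # working copy in p_dtype
        master_m, master_v = torch.zeros_like(master_p), torch.zeros_like(master_p)
        m, v = master_m.clone().to(p_dtype), master_v.clone().to(p_dtype)

        for step in range(1, n_steps + 1):
            master_g = torch.rand_like(master_p)                  # same gradient fed to both kernels
            reference.update(step, master_p, master_g, master_m, master_v)
            tested.update(step, p, master_g.to(g_dtype), m, v)
            assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol)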
@@ -109,9 +149,9 @@ def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float
         assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol)


-@pytest.mark.parametrize('adamw', [False, True])
-@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
-@pytest.mark.parametrize('p_dtype, g_dtype', _FUSED_ALLOWED_P_G_TYPES)
+@pytest.mark.parametrize("adamw", [False, True])
+@pytest.mark.parametrize("weight_decay", [0.0, 0.1])
+@pytest.mark.parametrize("p_dtype, g_dtype", _FUSED_ALLOWED_P_G_TYPES)
 def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
     rtol, atol = 1e-5, 1e-8
     if p_dtype is torch.float16 or g_dtype is torch.float16:
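Only the quote style changes here, but the stacked decorators are worth a note: pytest takes the Cartesian product of stacked parametrize lists, so each test below is collected once per (adamw, weight_decay, dtype-pair) combination. A standalone illustration of that mechanism (not part of this file):

    import pytest

    @pytest.mark.parametrize("adamw", [False, True])
    @pytest.mark.parametrize("weight_decay", [0.0, 0.1])
    def test_product(adamw, weight_decay):
        # collected four times: (False, 0.0), (True, 0.0), (False, 0.1), (True, 0.1)
        assert isinstance(adamw, bool) and weight_decay in (0.0, 0.1)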
@@ -121,11 +161,11 @@ def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
     check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol)


-@pytest.mark.parametrize('adamw', [False, True])
-@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
-@pytest.mark.parametrize('p_dtype, g_dtype', _CPU_ALLOWED_P_G_TYPES)
+@pytest.mark.parametrize("adamw", [False, True])
+@pytest.mark.parametrize("weight_decay", [0.0, 0.1])
+@pytest.mark.parametrize("p_dtype, g_dtype", _CPU_ALLOWED_P_G_TYPES)
 def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
     rtol, atol = 1e-5, 1e-8
     if p_dtype is torch.float16 or g_dtype is torch.float16:
         rtol, atol = 1e-3, 1e-3
-    check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device('cpu'), 3, rtol, atol)
+    check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device("cpu"), 3, rtol, atol)