Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-22 05:29:36 +00:00)

[hotfix] fix CPUAdam kernel nullptr (#1410)

parent: 1e5eb0874c
commit: 12b4887097
@@ -24,15 +24,12 @@ SOFTWARE
 #include <math.h>
 #include <omp.h>
 #include <string.h>
-#include <torch/extension.h>
 
 #include <iostream>
 #include <memory>
 #include <type_traits>
 #include <unordered_map>
 
-static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
-
 // C++ interface
 
 void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
@@ -310,35 +307,6 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                grad_half_precision, loss_scale);
 }
 
-int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
-                          float betta1 = 0.9, float betta2 = 0.999,
-                          float eps = 1e-8, float weight_decay = 0,
-                          bool adamw_mode = true, bool should_log = false) {
-  auto opt = std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps,
-                                              weight_decay, adamw_mode);
-
-  s_optimizers[optimizer_id] = opt;
-
-  if (should_log) {
-    std::string avx_type = "";
-#if defined(__AVX512__)
-    avx_type = "AVX512";
-#else
-#if defined(__AVX256__) or defined(__AVX2__)
-    avx_type = "AVX2";
-#else
-    avx_type = "scalar";
-#endif
-#endif
-    printf("Adam Optimizer #%d is created with %s arithmetic capability.\n",
-           optimizer_id, avx_type.c_str());
-    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
-           alpha, betta1, betta2, weight_decay, (int)adamw_mode);
-  }
-
-  return 0;
-}
-
 void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
@@ -460,11 +428,11 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                grad_half_precision, loss_scale);
 }
 
-int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
-              float epsilon, float weight_decay, bool bias_correction,
-              torch::Tensor &params, torch::Tensor &grads,
-              torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
-              float loss_scale) {
+void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
+                          float epsilon, float weight_decay,
+                          bool bias_correction, torch::Tensor &params,
+                          torch::Tensor &grads, torch::Tensor &exp_avg,
+                          torch::Tensor &exp_avg_sq, float loss_scale) {
   auto params_c = params.contiguous();
   auto grads_c = grads.contiguous();
   auto exp_avg_c = exp_avg.contiguous();
@@ -474,24 +442,18 @@ int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
   float *grads_ptr = (float *)grads_c.data_ptr();
   float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
   float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
-  std::shared_ptr<Adam_Optimizer> opt =
-      std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
-  opt->IncrementStep(step, beta1, beta2);
-  opt->update_state(lr, epsilon, weight_decay, bias_correction);
-  opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
-              params_c.numel(), (params.options().dtype() == at::kHalf),
-              (grads.options().dtype() == at::kHalf), loss_scale);
 
-  return 0;
+  this->IncrementStep(step, beta1, beta2);
+  this->update_state(lr, epsilon, weight_decay, bias_correction);
+  this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
+               params_c.numel(), (params.options().dtype() == at::kHalf),
+               (grads.options().dtype() == at::kHalf), loss_scale);
 }
 
-int destroy_adam_optimizer(int optimizer_id) {
-  s_optimizers.erase(optimizer_id);
-  return 0;
-}
+namespace py = pybind11;
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("adam_update", &adam_step, "CPU Adam update (C++)");
-  m.def("create_adam", &create_adam_optimizer, "CPU Adam (C++)");
-  m.def("destroy_adam", &destroy_adam_optimizer, "CPU Adam destroy (C++)");
+  py::class_<Adam_Optimizer>(m, "CPUAdamOptimizer")
+      .def(py::init<float, float, float, float, float, bool>())
+      .def("step", &Adam_Optimizer::step);
 }
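The binding hunk above is the core of the hotfix: the free functions keyed by an integer optimizer_id are gone. In the old code, looking up an id that was never created (or already destroyed) made `s_optimizers[optimizer_id]` return a default-constructed, null `shared_ptr`, which was then dereferenced; that is presumably the nullptr the commit title refers to. The extension now exposes `Adam_Optimizer` directly as the Python class `CPUAdamOptimizer`. A minimal sketch of driving the rebound extension by hand, assuming it is built from source as the `cpu_adam` module (tensor sizes and hyperparameter values are illustrative only):

import torch
import cpu_adam  # the rebuilt extension

# constructor arguments: lr, beta1, beta2, eps, weight_decay, adamw_mode
opt = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)

param = torch.rand(64)
grad = torch.rand(64)
exp_avg = torch.zeros(64)
exp_avg_sq = torch.zeros(64)

# step, lr, beta1, beta2, eps, weight_decay, bias_correction,
# params, grads, exp_avg, exp_avg_sq, loss_scale (-1, as the Python callers below pass)
opt.step(1, 1e-3, 0.9, 0.999, 1e-8, 0.0, True,
         param, grad, exp_avg, exp_avg_sq, -1)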
@@ -26,7 +26,7 @@ SOFTWARE
 #include <cuda_fp16.h>
 #include <cuda_runtime_api.h>
 #include <stdio.h>
-
+#include <torch/extension.h>
 #if (__x86_64__ || __i386__)
 #include <cpuid.h>
 #include <x86intrin.h>
@@ -141,6 +141,11 @@ class Adam_Optimizer {
     }
   }
 
+  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
+            float weight_decay, bool bias_correction, torch::Tensor &params,
+            torch::Tensor &grads, torch::Tensor &exp_avg,
+            torch::Tensor &exp_avg_sq, float loss_scale);
+
  private:
   float _alpha;
  float _betta1;
@@ -1,4 +1,3 @@
-from .utils import CPU_ADAM_CNT
 from .colossalai_optimizer import ColossalaiOptimizer
 from .fused_adam import FusedAdam
 from .fused_lamb import FusedLAMB
@@ -8,6 +7,4 @@ from .lars import Lars
 from .cpu_adam import CPUAdam
 from .hybrid_adam import HybridAdam
 
-__all__ = [
-    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT'
-]
+__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam']
@@ -2,7 +2,6 @@ import math
 import torch
 
 from colossalai.registry import OPTIMIZERS
-from colossalai.nn.optimizer import CPU_ADAM_CNT
 from .nvme_optimizer import NVMeOptimizer
 from typing import Optional
 
@@ -69,25 +68,17 @@ class CPUAdam(NVMeOptimizer):
                  eps=1e-8,
                  weight_decay=0,
                  adamw_mode=True,
-                 simd_log=False,
                  nvme_offload_fraction: float = 0.0,
                  nvme_offload_dir: Optional[str] = None):
 
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
-        self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
         try:
             import cpu_adam
         except ImportError:
             raise ImportError('Please install colossalai from source code to use CPUAdam')
-        self.cpu_adam_op = cpu_adam
-        self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
-
-    def __del__(self):
-        super().__del__()
-        if getattr(self, 'cpu_adam_op', None):
-            self.cpu_adam_op.destroy_adam(self.opt_id)
+        self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
 
     def torch_adam_update(self,
                           data,
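With the per-instance binding, CPUAdam no longer needs a globally unique opt_id, the simd_log flag, or a __del__ that tears down kernel state; constructing `cpu_adam.CPUAdamOptimizer` in __init__ is the only bookkeeping left. A minimal usage sketch under the new API (the model and hyperparameters are illustrative; the `cpu_adam` extension must be built from source):

import torch
import torch.nn as nn
from colossalai.nn.optimizer import CPUAdam

model = nn.Linear(64, 64)   # parameters live on CPU
optimizer = CPUAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                    eps=1e-8, weight_decay=0, adamw_mode=True)

loss = model(torch.rand(4, 64)).sum()
loss.backward()
optimizer.step()        # dispatches to cpu_adam.CPUAdamOptimizer.step for CPU params
optimizer.zero_grad()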
@@ -156,9 +147,9 @@ class CPUAdam(NVMeOptimizer):
                 assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                 assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
                 self._pre_update(p, 'exp_avg', 'exp_avg_sq')
-                self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
-                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
-                                             state['exp_avg'], state['exp_avg_sq'], -1)
+                self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
+                                      group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
+                                      state['exp_avg_sq'], -1)
                 self._post_update(p, 'exp_avg', 'exp_avg_sq')
             elif target_device.type == 'cuda':
                 assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
@@ -2,7 +2,6 @@ import torch
 
 from colossalai.utils import multi_tensor_applier
 from colossalai.registry import OPTIMIZERS
-from colossalai.nn.optimizer import CPU_ADAM_CNT
 from typing import Optional
 from .nvme_optimizer import NVMeOptimizer
 
@@ -68,13 +67,11 @@ class HybridAdam(NVMeOptimizer):
                  eps=1e-8,
                  weight_decay=0,
                  adamw_mode=True,
-                 simd_log=False,
                  nvme_offload_fraction: float = 0.0,
                  nvme_offload_dir: Optional[str] = None):
 
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
-        self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
         try:
             import cpu_adam
@@ -82,17 +79,11 @@ class HybridAdam(NVMeOptimizer):
         except ImportError:
             raise ImportError('Please install colossalai from source code to use HybridAdam')
 
-        self.cpu_adam_op = cpu_adam
-        self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
+        self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
 
         self.gpu_adam_op = colossal_C.multi_tensor_adam
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])
 
-    def __del__(self):
-        super().__del__()
-        if getattr(self, 'cpu_adam_op', None):
-            self.cpu_adam_op.destroy_adam(self.opt_id)
-
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
@@ -129,9 +120,9 @@ class HybridAdam(NVMeOptimizer):
                 assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                 assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
                 self._pre_update(p, 'exp_avg', 'exp_avg_sq')
-                self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
-                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
-                                             state['exp_avg'], state['exp_avg_sq'], -1)
+                self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
+                                      group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
+                                      state['exp_avg_sq'], -1)
                 self._post_update(p, 'exp_avg', 'exp_avg_sq')
 
             elif target_device.type == 'cuda':
@@ -1,14 +0,0 @@
-class CpuAdamCounter(object):
-    """Used to record the total number of CPU Adam.
-    We must use it to avoid hybrid cpu adam and cpu adam using the same id.
-    """
-
-    def __init__(self):
-        self.number = 0
-
-    def __call__(self):
-        self.number += 1
-        return self.number - 1
-
-
-CPU_ADAM_CNT = CpuAdamCounter()
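CpuAdamCounter existed only to hand out distinct ids so that CPUAdam and HybridAdam instances would not collide in the C++ side's global s_optimizers map; with that map removed, the counter and the CPU_ADAM_CNT export can be deleted outright. A short sketch of the consequence, assuming both wrappers and their extensions are available (HybridAdam additionally needs the CUDA `colossal_C` kernels); the models here are illustrative:

import torch.nn as nn
from colossalai.nn.optimizer import CPUAdam, HybridAdam

cpu_opt = CPUAdam(nn.Linear(8, 8).parameters(), lr=1e-3)
hybrid_opt = HybridAdam(nn.Linear(8, 8).cuda().parameters(), lr=1e-3)

# Each wrapper now holds its own cpu_adam.CPUAdamOptimizer object; there is
# no shared id space to manage, so ids can no longer collide or dangle.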
@@ -67,13 +67,11 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
 
     try:
         import cpu_adam
-        cpu_adam_op = cpu_adam
+        cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
     except:
         raise ImportError("Import cpu adam error, please install colossal from source code")
 
-    cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
-    cpu_adam_op.adam_update(
-        0,
+    cpu_adam_op.step(
         step,
         lr,
         beta1,
@@ -8,9 +8,11 @@ from colossalai.testing import parameterize
 
 
 class FC(nn.Module):
+
     def __init__(self) -> None:
         super().__init__()
         self.fc = nn.Sequential(nn.Linear(64, 64))
 
     def forward(self, x):
         return self.fc(x)
+
@@ -37,7 +39,7 @@ def test_adam(adamw, p_dtype, g_dtype):
 
     for d, l in zip(data, label):
         y = model(d)
-        loss = ((l - y) ** 2).sum()
+        loss = ((l - y)**2).sum()
         optim.zero_grad()
         loss.backward()
         if p_dtype != g_dtype:
@@ -47,7 +49,7 @@ def test_adam(adamw, p_dtype, g_dtype):
 
     for d, l in zip(data_copy, label):
         y = model_copy(d)
-        loss = ((l - y) ** 2).sum()
+        loss = ((l - y)**2).sum()
         torch_optim.zero_grad()
         loss.backward()
         torch_optim.step()
@@ -7,6 +7,7 @@ import math
 from colossalai.testing import parameterize
 from colossalai.utils import multi_tensor_applier
 
+
 def torch_adam_update(
     step,
     lr,
@@ -69,26 +70,26 @@ def test_adam(adamw, step, p_dtype, g_dtype):
         eps = 1e-8
         weight_decay = 0
 
-        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]],
-                             lr, beta1, beta2, eps, step, adamw,
-                             True, weight_decay)
+        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
+                             True, weight_decay)
 
         torch_adam_update(
             step,
             lr,
             beta1,
             beta2,
             eps,
             weight_decay,
             p_copy,    # fp32 data
             g_copy,    # fp32 grad
             m_copy,
             v_copy,
             adamw,
         )
 
         if torch.isnan(p).any() or torch.isnan(p_copy).any():
             count += 1
             continue
         assert count < 200, "too many nans"
-        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
+        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5,
+                              1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
@@ -38,4 +38,4 @@ def test_adam(adamw, device, p_dtype, g_dtype):
         if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any():
             continue
         assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \
             f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"