[optimizer] add div_scale for optimizers (#2117)

* [optimizer] add div_scale for optimizers

* [zero] use div_scale in zero optimizer

* fix test error
Author: HELSON
Date: 2022-12-12 17:58:57 +08:00 (committed by GitHub)
Parent: e5aa8333e4
Commit: e7d3afc9cc
8 changed files with 41 additions and 32 deletions
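
The change threads a new `div_scale` argument through the fused Adam kernel: when `div_scale > 0`, each gradient element is divided by it before the Adam math runs, so a caller such as the ZeRO optimizer can fold gradient unscaling (e.g., removing a mixed-precision loss scale) into the update kernel instead of launching a separate element-wise division. A minimal scalar sketch of the semantics, assuming (as the kernel below does) that `div_scale <= 0` means "disabled":

    // Hypothetical scalar illustration of the div_scale semantics; the real
    // kernel applies this element-wise inside AdamFunctor's ILP loop.
    inline float unscale_grad(float grad, float div_scale) {
      // div_scale <= 0 disables unscaling, matching `if (div_scale > 0)`.
      return (div_scale > 0.0f) ? grad / div_scale : grad;
    }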


@@ -17,8 +17,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             const float lr, const float beta1,
                             const float beta2, const float epsilon,
                             const int step, const int mode,
-                            const int bias_correction,
-                            const float weight_decay);
+                            const int bias_correction, const float weight_decay,
+                            const float div_scale);
 
 void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
@@ -46,4 +46,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Computes and apply update for LAMB optimizer");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors");
}
}
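
The hunk above only changes the trailing newline; the Python-facing binding for the Adam entry point needs no edit, because pybind11 deduces the signature from the C++ declaration, so the new `div_scale` parameter is exposed automatically. For illustration only (the actual binding line sits above this hunk and is assumed, not shown here):

    // Assumed shape of the existing binding; it picks up div_scale
    // automatically from the updated declaration of multi_tensor_adam_cuda.
    m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
          "Computes and apply update for Adam optimizer");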


@@ -28,7 +28,7 @@ struct AdamFunctor {
       int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
       const float beta1, const float beta2, const float beta1_correction,
       const float beta2_correction, const float epsilon, const float lr,
-      adamMode_t mode, const float decay) {
+      adamMode_t mode, const float decay, const float div_scale) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -79,6 +79,8 @@ struct AdamFunctor {
       }
 #pragma unroll
       for (int ii = 0; ii < ILP; ii++) {
+        if (div_scale > 0) r_g[ii] /= div_scale;
+
         if (mode == ADAM_MODE_0) {  // L2
           r_g[ii] = r_g[ii] + (decay * r_p[ii]);
           r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
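
In Adam terms, the single added line rescales the gradient that feeds the moment updates. For the L2 branch shown above, with weight decay \(\lambda = \texttt{decay}\) and scale \(s = \texttt{div\_scale}\) (the division applied only when \(s > 0\)):

\[
g \leftarrow \frac{g}{s}, \qquad
g \leftarrow g + \lambda\, p, \qquad
m \leftarrow \beta_1\, m + (1 - \beta_1)\, g
\]

The second-moment update and the parameter step fall below this excerpt's cutoff and are untouched by the diff; they simply consume the already-unscaled gradient.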
@@ -116,8 +118,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            const float lr, const float beta1,
                            const float beta2, const float epsilon,
                            const int step, const int mode,
-                           const int bias_correction,
-                           const float weight_decay) {
+                           const int bias_correction, const float weight_decay,
+                           const float div_scale) {
   using namespace at;
 
   // Handle bias correction mode
@@ -133,7 +135,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
       multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                             AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
                             beta2, bias_correction1, bias_correction2, epsilon,
-                            lr, (adamMode_t)mode, weight_decay);)
+                            lr, (adamMode_t)mode, weight_decay, div_scale);)
   AT_CUDA_CHECK(cudaGetLastError());
 }
 }
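
For orientation, a hedged sketch of a call through the extended host entry point. The four-list layout is suggested by TensorListMetadata<4> in the functor above; the list ordering ({grads, params, exp_avg, exp_avg_sq}, following the common Apex-style convention) is an assumption, as is every variable name. The ZeRO-side call added by this commit is not shown on this page.

    // Hypothetical call site; passing a non-positive div_scale (e.g. -1.0f)
    // reproduces the old behavior with no gradient unscaling.
    multi_tensor_adam_cuda(chunk_size, noop_flag,
                           {grads, params, exp_avg, exp_avg_sq},
                           lr, beta1, beta2, epsilon, step,
                           /*mode=*/0, /*bias_correction=*/1, weight_decay,
                           /*div_scale=*/loss_scale);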