mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 11:44:03 +00:00
[optimizer] add div_scale for optimizers (#2117)
* [optimizer] add div_scale for optimizers * [zero] use div_scale in zero optimizer * fix testing error
This commit is contained in:
@@ -17,8 +17,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
const float lr, const float beta1,
|
||||
const float beta2, const float epsilon,
|
||||
const int step, const int mode,
|
||||
const int bias_correction,
|
||||
const float weight_decay);
|
||||
const int bias_correction, const float weight_decay,
|
||||
const float div_scale);
|
||||
|
||||
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
@@ -46,4 +46,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
"Computes and apply update for LAMB optimizer");
|
||||
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
|
||||
"Computes L2 norm for a list of contiguous tensors");
|
||||
}
|
||||
}
|
||||
|
@@ -28,7 +28,7 @@ struct AdamFunctor {
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
|
||||
const float beta1, const float beta2, const float beta1_correction,
|
||||
const float beta2_correction, const float epsilon, const float lr,
|
||||
adamMode_t mode, const float decay) {
|
||||
adamMode_t mode, const float decay, const float div_scale) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
@@ -79,6 +79,8 @@ struct AdamFunctor {
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
if (div_scale > 0) r_g[ii] /= div_scale;
|
||||
|
||||
if (mode == ADAM_MODE_0) { // L2
|
||||
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
@@ -116,8 +118,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
const float lr, const float beta1,
|
||||
const float beta2, const float epsilon,
|
||||
const int step, const int mode,
|
||||
const int bias_correction,
|
||||
const float weight_decay) {
|
||||
const int bias_correction, const float weight_decay,
|
||||
const float div_scale) {
|
||||
using namespace at;
|
||||
|
||||
// Handle bias correction mode
|
||||
@@ -133,7 +135,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
|
||||
beta2, bias_correction1, bias_correction2, epsilon,
|
||||
lr, (adamMode_t)mode, weight_decay);)
|
||||
lr, (adamMode_t)mode, weight_decay, div_scale);)
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user