fix bugs in CPU adam (#633)

* add a CPU Adam instance counter shared by all CPU Adam optimizers

* fix an updating error in the Adam kernel
Author: HELSON
Date: 2022-04-02 17:04:05 +08:00
Committed by: GitHub
Parent: 1e2557e801
Commit: b31daed4cf
6 changed files with 25 additions and 13 deletions


@@ -2,6 +2,7 @@ import math
 import torch
 from colossalai.registry import OPTIMIZERS
+from colossalai.nn.optimizer import CPU_ADAM_CNT


 @OPTIMIZERS.register_module
@@ -51,7 +52,6 @@ class CPUAdam(torch.optim.Optimizer):
     https://openreview.net/forum?id=ryQu7f-RZ
     """
-    optimizer_id = 0
     # Number of fp32 shards per parameter:
     # param weight, grad, momentum and variance
     num_fp32_shards_per_param = 4
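The num_fp32_shards_per_param = 4 context line above is what makes CPU memory accounting simple: each parameter keeps four fp32 copies (weight, grad, momentum, variance). A back-of-the-envelope sketch of that arithmetic, with a hypothetical helper name:

# Hypothetical helper illustrating the memory implied by the four fp32 shards
# (param weight, grad, momentum, variance) CPUAdam keeps per parameter.
def cpu_adam_fp32_bytes(num_params: int, shards_per_param: int = 4) -> int:
    fp32_bytes = 4  # each fp32 value occupies 4 bytes
    return num_params * shards_per_param * fp32_bytes

# e.g. a 1-billion-parameter model pins roughly 16 GB of fp32 optimizer state:
assert cpu_adam_fp32_bytes(1_000_000_000) == 16_000_000_000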
@@ -68,8 +68,7 @@ class CPUAdam(torch.optim.Optimizer):
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args)
-        self.opt_id = CPUAdam.optimizer_id
-        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
+        self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
         try:
             import cpu_adam
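CPU_ADAM_CNT itself is not part of this diff; the sketch below is only a guess at the kind of shared counter it names: one process-wide callable that hands out unique opt_id values, so every CPU Adam optimizer (the "counter for all cpu adam" from the commit message) registers with the cpu_adam extension under a distinct ID. All names here other than CPU_ADAM_CNT and opt_id are hypothetical.

import threading

class CpuAdamCounter:
    # Thread-safe monotonic counter; calling the instance returns the next ID.
    def __init__(self):
        self._lock = threading.Lock()
        self._count = 0

    def __call__(self) -> int:
        with self._lock:
            current = self._count
            self._count += 1
            return current

CPU_ADAM_CNT = CpuAdamCounter()  # one shared instance for every CPU Adam

# Usage, mirroring the hunk above:
#     self.opt_id = CPU_ADAM_CNT()

Replacing the CPUAdam class attribute with a shared object like this means sibling optimizers that also wrap the cpu_adam extension can no longer restart the sequence at 0 and collide on the same ID.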
@@ -152,8 +151,8 @@ class CPUAdam(torch.optim.Optimizer):
                 assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                 assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"
-                bias_correction1 = 1 - beta1**state['step']
-                bias_correction2 = 1 - beta2**state['step']
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
                 # adam on cuda
                 self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
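The torch_adam_update call is truncated above and its body is not part of this diff; below is a minimal sketch of the bias-corrected Adam step such a method computes, assuming the standard Adam formulation. The function name matches the call site, but the argument list is hypothetical and weight decay / AdamW mode are omitted.

import torch

def torch_adam_update(param, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, step):
    # Update biased first- and second-moment estimates in place.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Bias corrections; the exponent is the step count after increment,
    # matching the bias_correction1/2 lines in the hunk above.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    # Parameter update: p -= (lr / bc1) * m / (sqrt(v / bc2) + eps)
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)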