fix bugs in CPU adam (#633)
* add a CPU Adam counter shared by all CPU Adam optimizers
* fix an updating error in the Adam kernel
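The bug being fixed: each optimizer class kept its own class-level counter, so opt_id values restarted at zero per class. Assuming the sibling CPUAdam used the same pattern (only HybridAdam's counter is visible in the diff below), a minimal sketch of the resulting collision:

# Simplified stand-ins for the real optimizers; the real classes wrap
# torch.optim.Optimizer and register with a C++ cpu_adam extension
# keyed by opt_id, so duplicate ids clash.
class CPUAdam:
    optimizer_id = 0

    def __init__(self):
        self.opt_id = CPUAdam.optimizer_id
        CPUAdam.optimizer_id += 1


class HybridAdam:
    optimizer_id = 0

    def __init__(self):
        self.opt_id = HybridAdam.optimizer_id
        HybridAdam.optimizer_id += 1


print(CPUAdam().opt_id, HybridAdam().opt_id)  # 0 0: two optimizers, one id

Routing every id through one shared counter, as the diff does, removes the collision.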
@@ -2,6 +2,7 @@ import torch
 
 from colossalai.utils import multi_tensor_applier
 from colossalai.registry import OPTIMIZERS
+from colossalai.nn.optimizer import CPU_ADAM_CNT
 
 
 @OPTIMIZERS.register_module
@@ -50,7 +51,6 @@ class HybridAdam(torch.optim.Optimizer):
     https://openreview.net/forum?id=ryQu7f-RZ
     """
 
-    optimizer_id = 0
     # Number of fp32 shards for per parameter
     # Param weight, grad, momentum and variance
     num_fp32_shards_per_param = 4
@@ -67,8 +67,7 @@ class HybridAdam(torch.optim.Optimizer):
 
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(HybridAdam, self).__init__(model_params, default_args)
-        self.opt_id = HybridAdam.optimizer_id
-        HybridAdam.optimizer_id = HybridAdam.optimizer_id + 1
+        self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
        try:
             import cpu_adam
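The shared counter itself lives in colossalai.nn.optimizer and is not shown in this diff. A minimal sketch of what such a process-wide counter could look like; the CpuAdamCounter class name and its internals are an assumption, only the CPU_ADAM_CNT name is confirmed by the import above:

class CpuAdamCounter:
    """Hands out a unique integer id to every CPU Adam optimizer.

    A single module-level instance is shared by CPUAdam and HybridAdam,
    which is what prevents two optimizers from reusing the same id.
    """

    def __init__(self):
        self.number = 0

    def __call__(self):
        # Return the current id, then advance the counter.
        self.number += 1
        return self.number - 1


# Module-level singleton, imported as CPU_ADAM_CNT in the diff above.
CPU_ADAM_CNT = CpuAdamCounter()

With this in place, each call of self.opt_id = CPU_ADAM_CNT() yields 0, 1, 2, ... across all CPU Adam optimizers in the process, regardless of which class they belong to, so every instance registers a distinct id with the C++ cpu_adam extension.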