From b31daed4cfae438bdb656439e240c6efb9f87494 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Sat, 2 Apr 2022 17:04:05 +0800
Subject: [PATCH] fix bugs in CPU adam (#633)

* add cpu adam counter for all cpu adam

* fixed updating error in adam kernel
---
 colossalai/kernel/cuda_native/csrc/cpu_adam.cpp |  2 +-
 colossalai/nn/optimizer/__init__.py             |  4 +++-
 colossalai/nn/optimizer/cpu_adam.py             |  9 ++++-----
 colossalai/nn/optimizer/hybrid_adam.py          |  5 ++---
 colossalai/nn/optimizer/utils.py                | 14 ++++++++++++++
 tests/test_moe/test_moe_zero_optim.py           |  4 +---
 6 files changed, 25 insertions(+), 13 deletions(-)
 create mode 100644 colossalai/nn/optimizer/utils.py

diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
index efd569fcd..f26360659 100644
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
@@ -493,7 +493,7 @@ int adam_step(int optimizer_id,
               grads_ptr,
               exp_avg_ptr,
               exp_avg_sq_ptr,
-              params_c.size(0),
+              params_c.numel(),
               (params.options().dtype() == at::kHalf),
               (grads.options().dtype() == at::kHalf),
               loss_scale);
diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py
index 06072648b..c3f1127aa 100644
--- a/colossalai/nn/optimizer/__init__.py
+++ b/colossalai/nn/optimizer/__init__.py
@@ -1,3 +1,4 @@
+from .utils import CPU_ADAM_CNT
 from .colossalai_optimizer import ColossalaiOptimizer
 from .fused_adam import FusedAdam
 from .fused_lamb import FusedLAMB
@@ -7,4 +8,5 @@ from .lars import Lars
 from .cpu_adam import CPUAdam
 from .hybrid_adam import HybridAdam
 
-__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam']
+__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD',
+           'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT']
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 1c6141fbb..475e615ea 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -2,6 +2,7 @@ import math
 
 import torch
 from colossalai.registry import OPTIMIZERS
+from colossalai.nn.optimizer import CPU_ADAM_CNT
 
 
 @OPTIMIZERS.register_module
@@ -51,7 +52,6 @@ class CPUAdam(torch.optim.Optimizer):
     https://openreview.net/forum?id=ryQu7f-RZ
     """
 
-    optimizer_id = 0
     # Number of fp32 shards for per parameter
     # Param weight, grad, momentum and variance
     num_fp32_shards_per_param = 4
@@ -68,8 +68,7 @@ class CPUAdam(torch.optim.Optimizer):
 
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args)
-        self.opt_id = CPUAdam.optimizer_id
-        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
+        self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
         try:
             import cpu_adam
@@ -152,8 +151,8 @@ class CPUAdam(torch.optim.Optimizer):
                 assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                 assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
 
-                bias_correction1 = 1 - beta1**state['step']
-                bias_correction2 = 1 - beta2**state['step']
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
 
                 # adam on cuda
                 self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index df9e54c1b..58486c233 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -2,6 +2,7 @@ import torch
 
 from colossalai.utils import multi_tensor_applier
 from colossalai.registry import OPTIMIZERS
+from colossalai.nn.optimizer import CPU_ADAM_CNT
 
 
 @OPTIMIZERS.register_module
@@ -50,7 +51,6 @@ class HybridAdam(torch.optim.Optimizer):
     https://openreview.net/forum?id=ryQu7f-RZ
     """
 
-    optimizer_id = 0
     # Number of fp32 shards for per parameter
     # Param weight, grad, momentum and variance
     num_fp32_shards_per_param = 4
@@ -67,8 +67,7 @@ class HybridAdam(torch.optim.Optimizer):
 
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(HybridAdam, self).__init__(model_params, default_args)
-        self.opt_id = HybridAdam.optimizer_id
-        HybridAdam.optimizer_id = HybridAdam.optimizer_id + 1
+        self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
         try:
             import cpu_adam
diff --git a/colossalai/nn/optimizer/utils.py b/colossalai/nn/optimizer/utils.py
new file mode 100644
index 000000000..513c169e5
--- /dev/null
+++ b/colossalai/nn/optimizer/utils.py
@@ -0,0 +1,14 @@
+class CpuAdamCounter(object):
+    """Used to record the total number of CPU Adam.
+    We must use it to avoid hybrid cpu adam and cpu adam using the same id.
+    """
+
+    def __init__(self):
+        self.number = 0
+
+    def __call__(self):
+        self.number += 1
+        return self.number - 1
+
+
+CPU_ADAM_CNT = CpuAdamCounter()
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index 91004545f..443b4ba50 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -1,7 +1,6 @@
 from functools import partial
 
 import colossalai
-from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
@@ -51,11 +50,10 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler):
 @parameterize("use_cpuadam", [True, False])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
-    MOE_CONTEXT.reset_loss()
     shard_strategy = shard_strategy_class()
     if use_cpuadam and cpu_offload is False:
         return
-
+    MOE_CONTEXT.reset_loss()
    get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
    _, train_dataloader, _, optimizer_class, criterion = get_components_func()
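
Note on the kernel fix: params_c.size(0) returns only the length of the tensor's first dimension, while params_c.numel() returns the total number of elements, which is the count the CPU Adam kernel needs so that every element of a multi-dimensional parameter is updated. A minimal PyTorch sketch of the difference (illustrative only, not part of the patch):

    import torch

    # For a 2-D weight matrix, size(0) counts rows only; numel() counts every element.
    w = torch.zeros(4, 8)
    print(w.size(0))  # 4  -> the old argument: the kernel would stop after a slice of the tensor
    print(w.numel())  # 32 -> the fixed argument: the kernel walks over all elements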
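
Note on the shared counter: CPUAdam and HybridAdam previously kept independent class-level optimizer_id counters, so one instance of each created in the same process could receive the same opt_id when registering with the cpu_adam extension, which is what the new CpuAdamCounter docstring warns against. A minimal sketch of the intended behaviour, where the Fake* classes are hypothetical stand-ins rather than library code:

    class CpuAdamCounter:
        # Process-wide counter mirroring colossalai/nn/optimizer/utils.py:
        # each call hands out the next unused id.
        def __init__(self):
            self.number = 0

        def __call__(self):
            self.number += 1
            return self.number - 1

    CPU_ADAM_CNT = CpuAdamCounter()

    class FakeCPUAdam:
        def __init__(self):
            self.opt_id = CPU_ADAM_CNT()  # first optimizer -> id 0

    class FakeHybridAdam:
        def __init__(self):
            self.opt_id = CPU_ADAM_CNT()  # next optimizer -> id 1, never a duplicate

    assert (FakeCPUAdam().opt_id, FakeHybridAdam().opt_id) == (0, 1)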