Fix bugs in CPU Adam (#633)

* Add a CPU Adam counter covering all CPU Adam instances

* Fix the update error in the Adam kernel
This commit is contained in:
HELSON
2022-04-02 17:04:05 +08:00
committed by GitHub
parent 1e2557e801
commit b31daed4cf
6 changed files with 25 additions and 13 deletions

View File

@@ -1,7 +1,6 @@
from functools import partial
import colossalai
from colossalai.utils.cuda import get_current_device
import pytest
import torch
import torch.multiprocessing as mp
@@ -51,11 +50,10 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler):
@parameterize("use_cpuadam", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
MOE_CONTEXT.reset_loss()
shard_strategy = shard_strategy_class()
if use_cpuadam and cpu_offload is False:
return
MOE_CONTEXT.reset_loss()
get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
_, train_dataloader, _, optimizer_class, criterion = get_components_func()