Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
fix bugs in CPU adam (#633)
* add cpu adam counter for all cpu adam
* fixed updating error in adam kernel
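The first bullet concerns step bookkeeping and the second the parameter-update math. As a rough illustration of what a CPU-side Adam step with an explicit per-optimizer step counter involves, here is a minimal PyTorch sketch; it is not the ColossalAI C++ kernel, and the function name and signature are made up for this example:

import torch

def cpu_adam_step(param, grad, exp_avg, exp_avg_sq, step,
                  lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Hypothetical single Adam update on CPU tensors. `step` is the 1-based
    # counter that has to be tracked per optimizer instance: the bias
    # correction below is wrong if the counter is stale or shared.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)               # first moment
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq.sqrt() / bias_correction2 ** 0.5).add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)

p, g = torch.randn(4), torch.randn(4)
m, v = torch.zeros(4), torch.zeros(4)
for t in range(1, 4):      # the counter advances once per optimizer step
    cpu_adam_step(p, g, m, v, t)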
@@ -1,7 +1,6 @@
 from functools import partial
-
 import colossalai
 from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
@@ -51,11 +50,10 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler):
 @parameterize("use_cpuadam", [True, False])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
-    MOE_CONTEXT.reset_loss()
     shard_strategy = shard_strategy_class()
     if use_cpuadam and cpu_offload is False:
         return
 
     MOE_CONTEXT.reset_loss()
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
     _, train_dataloader, _, optimizer_class, criterion = get_components_func()
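For context on the hunk above, the stacked @parameterize decorators run _run_test_sharded_optim_v2 over every combination of use_cpuadam and shard_strategy_class, and the early return skips CPU Adam when CPU offload is disabled; the hunk drops the redundant MOE_CONTEXT.reset_loss() before that guard and keeps the one after it. A minimal sketch of the stacked-decorator-plus-guard pattern, using a hypothetical stand-in for the ColossalAI testing decorator rather than the real one:

from functools import wraps

def parameterize(name, values):
    # Simplified stand-in: call the wrapped function once per value of
    # `name`; stacking several of these yields the cartesian product of
    # test configurations.
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, name: value})
        return wrapper
    return decorator

@parameterize("use_cpuadam", [True, False])
@parameterize("cpu_offload", [True, False])
def run_case(use_cpuadam, cpu_offload):
    # Same guard as in the diff: CPU Adam only applies together with CPU
    # offload, so the remaining combination is skipped before any setup.
    if use_cpuadam and cpu_offload is False:
        return
    print(f"running: use_cpuadam={use_cpuadam}, cpu_offload={cpu_offload}")

run_case()   # prints the three combinations that are not skipped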