[zero] sharded optim support hybrid cpu adam (#486)

* sharded optim support hybrid cpu adam

* update unit test

* polish docstring
Author: ver217
Date: 2022-03-22 14:56:59 +08:00 (committed by GitHub)
Parent: b334822163
Commit: 62b0a8d644
5 changed files with 64 additions and 48 deletions

@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
@@ -18,7 +19,6 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from common import CONFIG, check_sharded_params_padding
-from colossalai.amp import convert_to_apex_amp
 
 
 def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -42,12 +42,15 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
 @parameterize("cpu_offload", [True, False])
 @parameterize("use_cpuadam", [True, False])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
+@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
+def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = shard_strategy_class()
 
     if use_cpuadam and cpu_offload is False:
         return
+    if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam):
+        return
 
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -61,7 +64,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
             zero_model = model_builder(checkpoint=True)
         zero_model = ShardedModelV2(zero_model,
                                     shard_strategy,
-                                    offload_config=dict(device='cpu') if cpu_offload else None)
+                                    offload_config=dict(device='cpu') if cpu_offload else None,
+                                    use_memory_tracer=gpu_margin_mem_ratio > 0.0)
 
         model = model_builder(checkpoint=True).half()
         col_model_deepcopy(zero_model, model)
@@ -71,7 +75,11 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
             optimizer_class = CPUAdam
         optim = optimizer_class(model.parameters(), lr=1e-3)
         sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, cpu_offload=cpu_offload, initial_scale=2**5)
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           sharded_optim,
+                                           cpu_offload=cpu_offload,
+                                           initial_scale=2**5,
+                                           gpu_margin_mem_ratio=gpu_margin_mem_ratio)
 
         amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
         apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
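
For orientation, the hybrid CPU Adam setup this commit tests boils down to the sketch below. It is not part of the diff: only CPUAdam, ShardedModelV2, ShardedOptimizerV2 and their arguments are taken from the changes above, while the ZeRO import paths and the build_model helper are assumptions, and actual model construction (e.g. any ZeRO init context) is elided.

# Minimal sketch of hybrid CPU Adam with the sharded optimizer, assuming
# the names shown in this commit. Import paths for the ZeRO classes are
# not part of the diff and are best guesses for Colossal-AI of this era;
# build_model is a hypothetical stand-in for a model constructor.
from colossalai.nn.optimizer import CPUAdam
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

gpu_margin_mem_ratio = 0.7  # one of the values the test parameterizes over

# Parameters are offloaded to CPU; the memory tracer is enabled whenever a
# positive GPU margin ratio is used, mirroring the test's gating logic.
zero_model = ShardedModelV2(build_model(),
                            TensorShardStrategy(),
                            offload_config=dict(device='cpu'),
                            use_memory_tracer=gpu_margin_mem_ratio > 0.0)

# Hybrid setup: CPUAdam keeps optimizer states on CPU, while the sharded
# optimizer may use up to gpu_margin_mem_ratio of the GPU's spare memory
# to keep part of the update work on GPU.
optim = CPUAdam(zero_model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(zero_model,
                                   optim,
                                   cpu_offload=True,
                                   initial_scale=2**5,
                                   gpu_margin_mem_ratio=gpu_margin_mem_ratio)

The two early returns added to the test encode the same constraints: CPUAdam is only exercised with cpu_offload=True, and a non-zero gpu_margin_mem_ratio is only meaningful when CPU offload and CPUAdam are both enabled.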