mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[zero] sharded optim support hybrid cpu adam (#486)
* sharded optim support hybrid cpu adam
* update unit test
* polish docstring
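The substance of the change is the gpu_margin_mem_ratio argument threaded through the test below: with CPU offload and CPUAdam enabled, a fraction of the spare GPU memory (the "margin", measured with the memory tracer that the test now turns on) can still be spent on optimizer work, giving a hybrid CPU/GPU Adam instead of a purely host-side update. As a rough illustration of what a margin ratio means, here is a stand-alone sketch; it is not ColossalAI code, and split_states_by_margin and every number in the example are made up.

def split_states_by_margin(param_numels, bytes_per_elem, gpu_margin_mem_ratio, free_gpu_bytes):
    """Greedy split: keep optimizer states on GPU until the margin budget is used up."""
    budget = int(free_gpu_bytes * gpu_margin_mem_ratio)   # share of the spare GPU memory we may use
    on_gpu, on_cpu, used = [], [], 0
    for idx, numel in enumerate(param_numels):
        need = numel * bytes_per_elem * 2                 # Adam keeps exp_avg and exp_avg_sq per param
        if used + need <= budget:
            on_gpu.append(idx)                            # updated on the device
            used += need
        else:
            on_cpu.append(idx)                            # updated by CPUAdam on the host
    return on_gpu, on_cpu

# Example: fp32 states for three parameter tensors, 70% of a 64 MiB margin.
print(split_states_by_margin([1_000_000, 4_000_000, 8_000_000], 4, 0.7, 64 * 1024 * 1024))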
@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
@@ -18,7 +19,6 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from common import CONFIG, check_sharded_params_padding
-from colossalai.amp import convert_to_apex_amp
 
 
 def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -42,12 +42,15 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
 @parameterize("cpu_offload", [True, False])
 @parameterize("use_cpuadam", [True, False])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
+@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
+def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = shard_strategy_class()
 
     if use_cpuadam and cpu_offload is False:
         return
+    if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam):
+        return
 
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -61,7 +64,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
             zero_model = model_builder(checkpoint=True)
         zero_model = ShardedModelV2(zero_model,
                                     shard_strategy,
-                                    offload_config=dict(device='cpu') if cpu_offload else None)
+                                    offload_config=dict(device='cpu') if cpu_offload else None,
+                                    use_memory_tracer=gpu_margin_mem_ratio > 0.0)
 
         model = model_builder(checkpoint=True).half()
         col_model_deepcopy(zero_model, model)
@@ -71,7 +75,11 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
             optimizer_class = CPUAdam
         optim = optimizer_class(model.parameters(), lr=1e-3)
         sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, cpu_offload=cpu_offload, initial_scale=2**5)
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           sharded_optim,
+                                           cpu_offload=cpu_offload,
+                                           initial_scale=2**5,
+                                           gpu_margin_mem_ratio=gpu_margin_mem_ratio)
 
         amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
         apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
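For reference, the two early returns in the updated test prune the parameter grid: CPUAdam is only exercised together with CPU offload, and a non-zero gpu_margin_mem_ratio is only exercised when both cpu_offload and use_cpuadam are on, since the GPU margin is only relevant to the hybrid path. A small stand-alone sketch of which combinations survive the guards (shard_strategy_class is left out because it does not affect them; the grid values are copied from the decorators above):

from itertools import product

# Values copied from the @parameterize decorators in the diff above.
grid = product([True, False],    # cpu_offload
               [True, False],    # use_cpuadam
               [0.0, 0.7])       # gpu_margin_mem_ratio

for cpu_offload, use_cpuadam, gpu_margin_mem_ratio in grid:
    # Same early-return conditions as the test body.
    if use_cpuadam and cpu_offload is False:
        continue
    if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam):
        continue
    print(f"runs: cpu_offload={cpu_offload}, use_cpuadam={use_cpuadam}, "
          f"gpu_margin_mem_ratio={gpu_margin_mem_ratio}")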