[zero] improve adaptability for not-shard parameters (#708)

* adapt post grad hooks for not-shard parameters
* adapt optimizer for not-shard parameters
* offload gradients for not-replicated parameters
This commit is contained in:
HELSON
2022-04-11 13:38:51 +08:00
committed by GitHub
parent ab8c6b4a0e
commit a9b8300d54
9 changed files with 114 additions and 111 deletions

View File

@@ -48,8 +48,13 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler):
@parameterize("cpu_offload", [True])
@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug
@parameterize("reuse_fp16_shard", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
def _run_test_sharded_optim_v2(cpu_offload,
shard_strategy_class,
use_cpuadam,
reuse_fp16_shard,
gpu_margin_mem_ratio=0.0):
shard_strategy = shard_strategy_class()
if use_cpuadam and cpu_offload is False:
return
@@ -63,17 +68,15 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
shard_param=True):
zero_model = MoeModel()
zero_model = ShardedModelV2(
zero_model,
shard_strategy,
offload_config=dict(device='cpu') if cpu_offload else None,
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
reuse_fp16_shard=use_cpuadam,
)
zero_model = ShardedModelV2(zero_model,
shard_strategy,
offload_config=dict(device='cpu') if cpu_offload else None,
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
reuse_fp16_shard=reuse_fp16_shard)
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.is_replicated:
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))
model = MoeModel().half()
@@ -88,8 +91,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
sharded_optim,
cpu_offload=cpu_offload,
initial_scale=2**5,
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
keep_unsharded=True)
gpu_margin_mem_ratio=gpu_margin_mem_ratio)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)