[unittest] polish zero config in unittest (#438)

This commit is contained in:
Jiarui Fang
2022-03-17 10:20:53 +08:00
committed by GitHub
parent 640a6cd304
commit 17b8274f8a
3 changed files with 24 additions and 25 deletions

View File

@@ -10,6 +10,18 @@ from colossalai.zero.sharded_model import ShardedModelV2
# Module-level logger shared by this test module; get_dist_logger() is the
# project's distributed-aware logger factory (defined outside this file).
LOGGER = get_dist_logger()
_ZERO_OPTIMIZER_CONFIG = dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))
_ZERO_OFFLOAD_OPTIMIZER_CONFIG = dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False)
_ZERO_OFFLOAD_PARAM_CONFIG = dict(device='cpu', pin_memory=True, buffer_count=5, buffer_size=1e8, max_in_cpu=1e9)
ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
zero=dict(
optimzer=_ZERO_OPTIMIZER_CONFIG,
offload_optimizer_config=_ZERO_OFFLOAD_OPTIMIZER_CONFIG,
offload_param_config=_ZERO_OFFLOAD_PARAM_CONFIG,
),
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
CONFIG = dict(fp16=dict(mode=None,),
zero=dict(level=3,
verbose=False,