From 17b8274f8a689c79cfea0571f74ac471b88e6543 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 17 Mar 2022 10:20:53 +0800 Subject: [PATCH] [unittest] polish zero config in unittest (#438) --- tests/test_zero_data_parallel/common.py | 12 +++++++ .../test_sharded_optim_with_sync_bn.py | 6 ++-- .../test_zero_init_v2.py | 31 ++++++------------- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py index f9fe86bf6..4671dc3a5 100644 --- a/tests/test_zero_data_parallel/common.py +++ b/tests/test_zero_data_parallel/common.py @@ -10,6 +10,18 @@ from colossalai.zero.sharded_model import ShardedModelV2 LOGGER = get_dist_logger() +_ZERO_OPTIMIZER_CONFIG = dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)) +_ZERO_OFFLOAD_OPTIMIZER_CONFIG = dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False) +_ZERO_OFFLOAD_PARAM_CONFIG = dict(device='cpu', pin_memory=True, buffer_count=5, buffer_size=1e8, max_in_cpu=1e9) + +ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), + zero=dict( + optimzer=_ZERO_OPTIMIZER_CONFIG, + offload_optimizer_config=_ZERO_OFFLOAD_OPTIMIZER_CONFIG, + offload_param_config=_ZERO_OFFLOAD_PARAM_CONFIG, + ), + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) + CONFIG = dict(fp16=dict(mode=None,), zero=dict(level=3, verbose=False, diff --git a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py index da1f4edf2..2eecc4802 100644 --- a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py +++ b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py @@ -19,10 +19,8 @@ def run_dist(rank, world_size, port): # as this model has sync batch normalization # need to configure cudnn deterministic so that # randomness of convolution layers will be disabled - colossalai.launch(config=dict( - 
zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))), - cudnn_determinstic=True, - cudnn_benchmark=False), + zero_config = dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))) + colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False), rank=rank, world_size=world_size, host='localhost', diff --git a/tests/test_zero_data_parallel/test_zero_init_v2.py b/tests/test_zero_data_parallel/test_zero_init_v2.py index cc6d3d4d3..f7696eef5 100644 --- a/tests/test_zero_data_parallel/test_zero_init_v2.py +++ b/tests/test_zero_data_parallel/test_zero_init_v2.py @@ -8,32 +8,21 @@ import pytest import colossalai from colossalai.utils import free_port -import torch import torch.multiprocessing as mp from tests.components_to_test.registry import non_distributed_component_funcs - -from common import check_sharded_params_padding +from common import check_sharded_params_padding, ZERO_PARALLEL_CONFIG def run_dist(rank, world_size, port): - _config = dict(fp16=dict(mode=None,), - zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)), - offload_optimizer_config=dict(device='cpu', - pin_memory=True, - buffer_count=5, - fast_init=False), - offload_param_config=dict(device='cpu', - pin_memory=True, - buffer_count=5, - buffer_size=1e8, - max_in_cpu=1e9)), - parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) + colossalai.launch(config=ZERO_PARALLEL_CONFIG, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') - colossalai.launch(config=_config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - # FIXME revert back - # test_models = ['repeated_computed_layers', 'resnet18', 'bert'] - test_models = ['bert'] + test_models = ['repeated_computed_layers', 'resnet18', 'bert'] for model_name in test_models: get_components_func = 
non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() @@ -65,8 +54,8 @@ def run_dist(rank, world_size, port): output = engine(data) loss = engine.criterion(output, label) - torch_model(data, label) - torch_loss = engine.criterion(output, label) + torch_output = torch_model(data) + torch_loss = engine.criterion(torch_output, label) else: loss = engine(data, label) torch_loss = torch_model(data, label)