From 534afb018a7827a2e8182ad1be84dcec65eb3ef4 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Mon, 9 May 2022 17:07:35 +0800
Subject: [PATCH] test pretrain loading on multi-process (#922)

---
 tests/test_tensor/test_model.py | 64 +++++++++++++++++++++------------
 1 file changed, 41 insertions(+), 23 deletions(-)

diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index c75242100..8f8fb4597 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -278,26 +278,6 @@ def test_colo_optimizer():
         if i > 5:
             break
 
-def _test_pretrained():
-    from _utils import check_equal
-    from transformers import BertForMaskedLM
-    set_seed(1)
-    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
-    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
-        model = BertForMaskedLM.from_pretrained('bert-base-uncased')
-
-    model_pretrained = model_pretrained.cuda()
-    model = model.cuda()
-
-    dict_pretrained = {}
-    dict_col = {}
-    for name, param in model_pretrained.named_parameters():
-        dict_pretrained[name] = param
-    for name, param in model.named_parameters():
-        dict_col[name] = param
-
-    for name, param in dict_pretrained.items():
-        check_equal(param, dict_col[name])
 
 def run_1d_row_tp(model_name: str):
     # A simple net with two stacked nn.Linear
@@ -376,7 +356,29 @@ def run_1d_row_tp(model_name: str):
             break
 
 
-def run_dist(rank, world_size, port):
+def _run_pretrain_load():
+    from _utils import check_equal
+    from transformers import BertForMaskedLM
+    set_seed(1)
+    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
+    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+        model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+
+    model_pretrained = model_pretrained.cuda()
+    model = model.cuda()
+
+    dict_pretrained = {}
+    dict_col = {}
+    for name, param in model_pretrained.named_parameters():
+        dict_pretrained[name] = param
+    for name, param in model.named_parameters():
+        dict_col[name] = param
+
+    for name, param in dict_pretrained.items():
+        check_equal(param, dict_col[name])
+
+
+def run_model_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     for name in ['simple_net']:
@@ -390,7 +392,23 @@ def run_dist(rank, world_size, port):
 #@parameterize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_model(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    run_func = partial(run_model_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
+def run_pretrain_load_dist(rank, world_size, port):
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    _run_pretrain_load()
+
+
+# This test case has to download Hugging Face pretrained models from the internet,
+# so we trigger it manually.
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def _test_pretrain_load(world_size):
+    run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
@@ -398,4 +416,4 @@ if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
     # test_model()
-    _test_pretrained()
+    _test_pretrain_load(4)
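
Note on the launch pattern: _test_pretrain_load keeps its leading underscore so that
pytest skips it during normal collection (the model download makes it unsuitable for
CI), and it is invoked from the __main__ block instead. For readers who want to reuse
the multi-process skeleton, a minimal sketch follows. It is an illustration, not part
of the patch: it assumes colossalai and its free_port utility are importable and that
one CUDA device per process is available for the NCCL backend; worker is a hypothetical
stand-in for run_pretrain_load_dist with the test body elided.

    from functools import partial

    import torch.multiprocessing as mp

    import colossalai
    from colossalai.utils import free_port


    def worker(rank, world_size, port):
        # Each spawned process joins the same 1D tensor-parallel group.
        config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size)))
        colossalai.launch(config=config, rank=rank, world_size=world_size,
                          host='localhost', port=port, backend='nccl')
        # ... the test body would run here, e.g. _run_pretrain_load() ...


    if __name__ == '__main__':
        # mp.spawn supplies the rank as the first positional argument;
        # world_size and port are bound ahead of time with partial.
        run_func = partial(worker, world_size=4, port=free_port())
        mp.spawn(run_func, nprocs=4)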