From f6178728a0319fe58d02936e55dc08dfc19aca73 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Wed, 30 Nov 2022 17:06:10 +0800
Subject: [PATCH] [gemini] fix init bugs for modules (#2047)

* [gemini] fix init bugs for modules

* fix bugs
---
 colossalai/utils/model/colo_init_context.py |  5 -----
 tests/test_gemini/update/test_optim.py      | 17 +++++++----------
 2 files changed, 7 insertions(+), 15 deletions(-)

diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index b7fef99b4..7a9b3ff25 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -96,10 +96,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         FIXME(fjr) The module may be passed to this function multiple times?
         """
-
-        if hasattr(module, '_colo_visited'):
-            return
-
         name_list = []
         for name, param in _named_params_with_replica(module):
             if isinstance(param, ColoTensor):
@@ -130,7 +126,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                     colo_param.shared_param_modules.append(submodule)
 
         module.to(self._device)
-        ColoModulize(module)
 
 
 def post_process_colo_init_ctx(model: torch.nn.Module,
diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py
index 93164995d..8dce2915a 100644
--- a/tests/test_gemini/update/test_optim.py
+++ b/tests/test_gemini/update/test_optim.py
@@ -24,6 +24,11 @@ from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed
 
+# this model is large enough to slice to chunks
+TEST_MODELS = ['gpt2']
+# these models are too small, all parameters in these models are compacted into one chunk
+EXAMPLE_MODELS = ['hanging_param_model', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
+
 
 def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
     zero_dict = model.state_dict(only_rank_0=False)
@@ -40,10 +45,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)
 
 
-# 'gpt2', 'bert',
-TEST_MODELS = ['hanging_param_model', 'gpt2', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
-
-
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 @parameterize('model_name', TEST_MODELS)
 def exam_model_step(placement_policy, model_name: str):
@@ -61,8 +62,6 @@ def exam_model_step(placement_policy, model_name: str):
 
     with ColoInitContext(device=init_dev):
         model = model_builder()
 
-    post_process_colo_init_ctx(model, device=init_dev)
-
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
@@ -102,8 +101,8 @@ def exam_model_step(placement_policy, model_name: str):
     check_param(model, torch_model)
 
 
-@parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', TEST_MODELS)
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('model_name', EXAMPLE_MODELS)
 def exam_tiny_example(placement_policy, model_name: str):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -119,8 +118,6 @@ def exam_tiny_example(placement_policy, model_name: str):
 
     with ColoInitContext(device=init_dev):
         model = model_builder()
 
-    post_process_colo_init_ctx(model, device=init_dev)
-
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
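
Note (not part of the patch): with the _colo_visited early return and the ColoModulize call removed, _post_init_method processes a module every time it is constructed, and the updated tests build models under ColoInitContext without calling post_process_colo_init_ctx right after the context. A minimal usage sketch under those assumptions follows; build_colo_model and the CUDA device choice are illustrative, and model_builder stands for a callable like the ones in tests/components_to_test.

import torch

from colossalai.utils.model.colo_init_context import ColoInitContext


def build_colo_model(model_builder):
    # Construct the module inside the context so that its parameters are
    # converted to ColoParameters on the chosen device during __init__.
    with ColoInitContext(device=torch.device('cuda')):
        model = model_builder()
    return model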