[hotfix] hotfix Gemini for no leaf modules bug (#2043)

Jiarui Fang, 2022-11-30 14:53:41 +08:00 (committed via GitHub)
parent 384cd26314
commit 31c644027b
2 changed files with 82 additions and 28 deletions


@@ -15,10 +15,11 @@ from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
+from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
 from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed
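
Note on the new import: post_process_colo_init_ctx is the helper this hotfix pairs with ColoInitContext for models whose parameters are not owned by leaf submodules. A minimal sketch of the idea, assuming the helper only needs to re-wrap any parameter that is still a plain torch.nn.Parameter as a ColoParameter on the target device; the actual implementation in colossalai/utils/model/colo_init_context.py may differ:

import torch
from colossalai.tensor import ColoParameter

def post_process_sketch(model: torch.nn.Module, device: torch.device) -> None:
    # Hypothetical re-implementation for illustration only.
    for module in model.modules():
        for name, param in list(module.named_parameters(recurse=False)):
            if isinstance(param, ColoParameter):
                continue    # already converted inside ColoInitContext
            # Re-register the leftover parameter as a ColoParameter on `device`.
            colo_param = ColoParameter(param.data.to(device), requires_grad=param.requires_grad)
            setattr(module, name, colo_param)
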
@@ -40,8 +41,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
 # 'gpt2', 'bert',
-TEST_MODELS = ['gpt2', 'bert']
-EXAMPLE_MODELS = ['simple_net']
+TEST_MODELS = ['no_leaf_module', 'gpt2', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
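
The reworked list folds the old EXAMPLE_MODELS entries into TEST_MODELS and, crucially, adds 'no_leaf_module', the regression case for this bug. As a rough illustration of what "no leaf module" means here (a hypothetical toy, not the registered test component), think of a module that owns a parameter directly instead of delegating it to a leaf child:

import torch
import torch.nn as nn

class NoLeafToy(nn.Module):
    # The weight lives on the root module itself, so no leaf submodule owns it;
    # presumably the kind of parameter the new post-processing pass is meant to catch.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight
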
@@ -57,8 +57,12 @@ def exam_model_step(placement_policy, model_name: str):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
-    with ColoInitContext(device=get_current_device()):
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
         model = model_builder()
+    post_process_colo_init_ctx(model, device=init_dev)
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
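
Read out of the diff, the updated initialization flow in exam_model_step is simply the following (a sketch; model_builder comes from the test's component registry):

from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx

init_dev = get_current_device()
with ColoInitContext(device=init_dev):
    model = model_builder()
# Catch any parameters the context missed (the "no leaf modules" case).
post_process_colo_init_ctx(model, device=init_dev)
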
@@ -99,7 +103,7 @@ def exam_model_step(placement_policy, model_name: str):
 @parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', EXAMPLE_MODELS)
+@parameterize('model_name', TEST_MODELS)
 def exam_tiny_example(placement_policy, model_name: str):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
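
With this change exam_tiny_example sweeps the full TEST_MODELS list instead of the old EXAMPLE_MODELS. For readers unfamiliar with colossalai.testing.parameterize: stacking the decorators runs the wrapped function once per combination of the listed values when it is called with no arguments, roughly like this toy example (not taken from this diff):

from colossalai.testing import parameterize

@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
def show(placement_policy, model_name):
    print(placement_policy, model_name)

show()    # expected to run all four (placement_policy, model_name) pairs
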
@@ -111,8 +115,12 @@ def exam_tiny_example(placement_policy, model_name: str):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
-    with ColoInitContext(device=get_current_device()):
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
         model = model_builder()
+    post_process_colo_init_ctx(model, device=init_dev)
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
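
For completeness, checks like these are normally driven by a small torch.multiprocessing harness elsewhere in the test file; a minimal sketch of that pattern (world size, port handling, and function names here are assumptions, not part of this diff):

import colossalai
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port

def run_dist(rank, world_size, port):
    # Each spawned process joins the distributed group, then runs the parameterized checks.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_model_step()
    exam_tiny_example()

@rerun_if_address_is_in_use()
def test_optim(world_size=1):
    mp.spawn(run_dist, args=(world_size, free_port()), nprocs=world_size)
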