[hotfix] hotfix Gemini for no leaf modules bug (#2043)
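Test-side changes in this commit, as visible in the diff below: the new no_leaf_module component is added to TEST_MODELS, exam_tiny_example is switched from EXAMPLE_MODELS to TEST_MODELS, and both exam functions now build the model under ColoInitContext and then run post_process_colo_init_ctx over it.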
@@ -15,10 +15,11 @@ from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
+from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed
@@ -40,8 +41,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
 
 
-# 'gpt2', 'bert',
-TEST_MODELS = ['gpt2', 'bert']
+TEST_MODELS = ['no_leaf_module', 'gpt2', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
 EXAMPLE_MODELS = ['simple_net']
 
 
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
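For context, a minimal sketch of what a no_leaf_module test component might look like: a module that holds a parameter directly, so the parameter belongs to no leaf submodule, which is the case the commit title says Gemini previously mishandled. The class name and structure are illustrative assumptions, not the actual tests/components_to_test code.

import torch
import torch.nn as nn


class NoLeafModule(nn.Module):
    # hypothetical example: `weight` hangs directly off this non-leaf
    # module instead of living inside a leaf submodule such as nn.Linear
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 8)
        self.weight = nn.Parameter(torch.randn(8, 8))

    def forward(self, x):
        return torch.matmul(self.proj(x), self.weight)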
@@ -57,8 +57,12 @@ def exam_model_step(placement_policy, model_name: str):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
 
-    with ColoInitContext(device=get_current_device()):
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
         model = model_builder()
+
+    post_process_colo_init_ctx(model, device=init_dev)
+
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
 
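Pulled out of the hunk above, the new initialization flow reads as follows. model_builder stands in for the callable the test gets from its component registry, and the comment on what post_process_colo_init_ctx does is an inference from the commit title rather than from its implementation.

from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx

init_dev = get_current_device()
with ColoInitContext(device=init_dev):
    model = model_builder()    # any callable returning an nn.Module

# presumably converts parameters the context could not capture
# (e.g. those hanging off non-leaf modules) for Gemini/ZeroDDP
post_process_colo_init_ctx(model, device=init_dev)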
@@ -99,7 +103,7 @@ def exam_model_step(placement_policy, model_name: str):
 
 
 @parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', EXAMPLE_MODELS)
+@parameterize('model_name', TEST_MODELS)
 def exam_tiny_example(placement_policy, model_name: str):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -111,8 +115,12 @@ def exam_tiny_example(placement_policy, model_name: str):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
 
-    with ColoInitContext(device=get_current_device()):
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
         model = model_builder()
+
+    post_process_colo_init_ctx(model, device=init_dev)
+
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
 
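The diff does not show how these exam functions are driven, but the file's imports (free_port, rerun_if_address_is_in_use) suggest the usual ColossalAI multiprocess test harness. The run_dist and test_gemini_optim names below are assumptions for illustration, not part of this commit.

from functools import partial

import torch.multiprocessing as mp

import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def run_dist(rank, world_size, port):
    # each spawned worker joins the process group, then runs the exams;
    # @parameterize makes the exam functions iterate over their own arguments
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_model_step()
    exam_tiny_example()


@rerun_if_address_is_in_use()
def test_gemini_optim(world_size=1):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)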