[gemini] fix init bugs for modules (#2047)

* [gemini] fix init bugs for modules

* fix bugs
HELSON
2022-11-30 17:06:10 +08:00
committed by GitHub
parent 81e0da7fa8
commit f6178728a0
2 changed files with 7 additions and 15 deletions


@@ -24,6 +24,11 @@ from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed
 
+# this model is large enough to slice to chunks
+TEST_MODELS = ['gpt2']
+# these models are too small, all parameters in these models are compacted into one chunk
+EXAMPLE_MODELS = ['hanging_param_model', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
+
 
 def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
     zero_dict = model.state_dict(only_rank_0=False)
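The two comments added above encode a sizing rule: a model's total parameter count decides whether Gemini slices it across several chunks or compacts it into one. A minimal sketch of that rule in plain torch follows; the chunk capacity here is a made-up illustrative number, not the value Gemini actually uses.

import torch.nn as nn

CHUNK_NUMEL = 1024 * 1024  # hypothetical chunk capacity, in elements

def fits_in_one_chunk(model: nn.Module) -> bool:
    # a model whose parameters all fit under the capacity is compacted
    # into a single chunk, like the EXAMPLE_MODELS above
    total = sum(p.numel() for p in model.parameters())
    return total <= CHUNK_NUMEL

tiny = nn.Linear(64, 64)          # stands in for one of the EXAMPLE_MODELS
print(fits_in_one_chunk(tiny))    # True: 64*64 + 64 = 4160 elements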
@@ -40,10 +45,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)
 
 
-# 'gpt2', 'bert',
-TEST_MODELS = ['hanging_param_model', 'gpt2', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
-
-
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 @parameterize('model_name', TEST_MODELS)
 def exam_model_step(placement_policy, model_name: str):
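Stacked @parameterize decorators like the two above run the test once per combination of the listed values. A minimal sketch of how such a decorator expands into that Cartesian product follows; this is an illustration, not ColossalAI's actual colossalai.testing.parameterize implementation.

import functools

def parameterize(arg_name, values):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(**kwargs):
            # run the wrapped function once per value of this argument,
            # keeping any arguments already bound by outer decorators
            for value in values:
                func(**kwargs, **{arg_name: value})
        return wrapper
    return decorator

@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', ['gpt2'])
def exam_model_step(placement_policy, model_name):
    print(placement_policy, model_name)

exam_model_step()  # runs 4 x 1 = 4 combinations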
@@ -61,8 +62,6 @@ def exam_model_step(placement_policy, model_name: str):
     with ColoInitContext(device=init_dev):
         model = model_builder()
 
-    post_process_colo_init_ctx(model, device=init_dev)
-
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
 
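The copy loop retained above makes the Colo-initialized model start from the same weights as the torch reference model, so the test can later compare them step for step. A self-contained sketch of the same alignment, using plain torch modules in place of the ColossalAI types:

import torch
import torch.nn as nn

torch_model = nn.Linear(8, 8)   # reference model
model = nn.Linear(8, 8)         # freshly built model to align

with torch.no_grad():
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

assert torch.equal(torch_model.weight, model.weight)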
@@ -102,8 +101,8 @@ def exam_model_step(placement_policy, model_name: str):
     check_param(model, torch_model)
 
 
-@parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', TEST_MODELS)
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('model_name', EXAMPLE_MODELS)
 def exam_tiny_example(placement_policy, model_name: str):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -119,8 +118,6 @@ def exam_tiny_example(placement_policy, model_name: str):
     with ColoInitContext(device=init_dev):
         model = model_builder()
 
-    post_process_colo_init_ctx(model, device=init_dev)
-
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
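For reference, the check_param helper touched in the first hunk verifies that every entry of the ZeroDDP state dict matches the fp32 torch reference within fp16-level tolerances. A sketch of that comparison with plain torch stand-ins, using the same rtol/atol as the test:

import torch
from torch.testing import assert_close

torch_model = torch.nn.Linear(8, 8)
model = torch.nn.Linear(8, 8)
model.load_state_dict(torch_model.state_dict())

torch_dict = torch_model.state_dict()
for key, value in model.state_dict().items():
    # cast the reference to the test value's dtype before comparing,
    # mirroring the temp_zero_value conversion in check_param above
    assert_close(value, torch_dict[key].to(value.dtype), rtol=1e-3, atol=1e-2)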