diff --git a/tests/test_tensor/test_zero_optim.py b/tests/test_tensor/test_zero_optim.py index 3b374fa1e..b08ceed32 100644 --- a/tests/test_tensor/test_zero_optim.py +++ b/tests/test_tensor/test_zero_optim.py @@ -18,6 +18,7 @@ from colossalai.testing import parameterize from colossalai.amp import convert_to_apex_amp from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor +from tests.test_tensor.model.test_gpt2 import init_megatron_spec def check_param_equal(model, torch_model, pg: ProcessGroup): @@ -127,10 +128,10 @@ def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') if world_size == 4: - run_gpt(tp_init_spec_func=init_1d_col_spec) - run_gpt(tp_init_spec_func=init_1d_row_spec) + run_gpt(tp_init_spec_func=init_megatron_spec) else: run_gpt(tp_init_spec_func=init_1d_col_spec) + run_gpt(tp_init_spec_func=init_1d_row_spec) @pytest.mark.dist