[Test/CI] remove test cases to reduce CI duration (#5753)

* [test] smaller gpt2 test case

* [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py

* [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py

* [test] reduce test cases tests/test_zero/test_gemini/test_optim.py

* Revert "[test] smaller gpt2 test case"

Some tests might depend on the size of the model (i.e., the number of chunks); see the sketch below.

This reverts commit df705a5210.
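
A rough illustration of why model size matters here: Gemini packs parameters into fixed-size chunks, so shrinking the model can change the resulting chunk count and, with it, test behavior. This is a minimal sketch, not part of the commit; the chunk size is a hypothetical value, and the real chunk manager packs tensors greedily rather than by simple division:

    import math

    def estimate_num_chunks(total_params: int, chunk_size: int) -> int:
        # Rough estimate only: the actual chunk manager packs tensors greedily.
        return math.ceil(total_params / chunk_size)

    # A smaller GPT2 can land in fewer chunks than the stock one:
    print(estimate_num_chunks(124_000_000, 32 * 1024**2))  # stock gpt2  -> 4
    print(estimate_num_chunks(7_000_000, 32 * 1024**2))    # tiny config -> 1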

* [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py

* [CI] smaller test model for the two modified cases

* [CI] hardcode the GPT model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there
Author: botbw
Date: 2024-06-05 11:29:04 +08:00
Committed by: GitHub
Parent: 79f7a7b211
Commit: 80c3c8789b
6 changed files with 40 additions and 76 deletions

--- a/tests/test_zero/test_gemini/test_search.py
+++ b/tests/test_zero/test_gemini/test_search.py

@@ -1,18 +1,31 @@
 import pytest
 import torch
+import transformers
 
 import colossalai
 from colossalai.accelerator import get_accelerator
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
-from tests.kit.model_zoo import model_zoo
+
+CONFIG = transformers.GPT2Config(
+    n_layer=2,
+    n_head=4,
+    n_embd=128,
+    vocab_size=50258,
+    attn_pdrop=0,
+    embd_pdrop=0,
+    resid_pdrop=0,
+    summary_first_dropout=0,
+    hidden_dropout=0,
+    problem_type="single_label_classification",
+    pad_token_id=50256,
+    tie_word_embeddings=True,
+)
+
+model_builder = lambda: transformers.GPT2LMHeadModel(CONFIG)
 
 
 def exam_search_chunk_size():
-    model_builder, data_gen_fn, output_transform_fn, *_ = next(
-        iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
-    )
-
     # make sure torch_model and model has the same parameter values
     model = model_builder()
     config_dict, *_ = search_chunk_configuration(
@@ -27,10 +40,6 @@ def exam_search_chunk_size():
 
 def exam_chunk_manager():
     world_size = torch.distributed.get_world_size()
-    model_builder, data_gen_fn, output_transform_fn, *_ = next(
-        iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
-    )
-
     sharded_ddp_model = model_builder()
     chunk_manager = init_chunk_manager(
         sharded_ddp_model,
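
As a quick sanity check (not part of the commit), instantiating the hardcoded config above shows why it keeps CI fast: the resulting model has only about 7M parameters, versus roughly 124M for stock gpt2. A minimal snippet, assuming only the config values from the diff:

    import transformers

    # Illustrative only: build the hardcoded test config and count parameters
    # to confirm the model used by test_search.py is tiny.
    config = transformers.GPT2Config(
        n_layer=2,
        n_head=4,
        n_embd=128,
        vocab_size=50258,
        pad_token_id=50256,
        tie_word_embeddings=True,
    )
    model = transformers.GPT2LMHeadModel(config)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{n_params / 1e6:.1f}M parameters")  # ~7.0M with tied embeddings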