Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-17 23:18:36 +00:00.
[workflow] fixed build CI (#5240)
* [workflow] fixed build CI
* polish
* polish
* polish
* polish
* polish
@@ -7,6 +7,7 @@ from transformers import LlamaForCausalLM
 from utils import shared_tempdir
 
 import colossalai
+from colossalai.testing import skip_if_not_enough_gpus
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.lazy import LazyInitContext
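The single added line above imports skip_if_not_enough_gpus from colossalai.testing; it is applied further down in this diff to guard the 8-GPU variant of the checkpoint IO test. As a rough mental model only (a hypothetical sketch, not ColossalAI's actual implementation), such a guard can be built from torch.cuda.device_count() and pytest.skip:

import functools

import pytest
import torch


def skip_if_not_enough_gpus_sketch(min_gpus: int):
    """Hypothetical stand-in for colossalai.testing.skip_if_not_enough_gpus."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Skip instead of failing when the machine has fewer GPUs than required.
            if torch.cuda.device_count() < min_gpus:
                pytest.skip(f"requires at least {min_gpus} GPUs")
            return fn(*args, **kwargs)

        return wrapper

    return decorator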
@@ -68,7 +69,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 @clear_cache_before_run()
 @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("tp_size", [1, 2])
 @parameterize("zero_size", [2])
@@ -156,13 +157,12 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_gemini_ckpIO(world_size):
-    spawn(run_dist, world_size)
+def test_gemini_ckpIO():
+    spawn(run_dist, 4)
 
 @pytest.mark.largedist
-@pytest.mark.parametrize("world_size", [8])
+@skip_if_not_enough_gpus(min_gpus=8)
 @rerun_if_address_is_in_use()
-def test_gemini_ckpIO_3d(world_size):
-    spawn(run_dist, world_size)
+def test_gemini_ckpIO_3d():
+    spawn(run_dist, 8)
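Read directly from the + lines of this hunk, the two entry points after the change are as follows: the regular test no longer receives world_size from pytest parametrization and always spawns 4 processes, while the 3D variant is marked largedist and skipped on machines with fewer than 8 GPUs.

@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_gemini_ckpIO():
    spawn(run_dist, 4)


@pytest.mark.largedist
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_gemini_ckpIO_3d():
    spawn(run_dist, 8)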
@@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo
 
 @clear_cache_before_run()
 @parameterize("shard", [False, True])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 def exam_torch_load_from_gemini(shard: bool, model_name: str):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
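The unpacking on the first line of the function body reads an entry from the test model zoo. A minimal, hedged sketch of how such an entry appears to be consumed follows; the registry value layout is inferred only from the unpacking shown in this hunk and in the last hunk of this commit, not from library documentation.

from tests.kit.model_zoo import model_zoo

# Assumed layout of a registry value: (model builder, data generator, output transform, loss fn, ...)
model_fn, data_gen_fn, output_transform_fn, _, _ = next(
    iter(model_zoo.get_sub_registry("transformers_llama_for_casual_lm").values())
)

model = model_fn()                  # fresh model instance
data = data_gen_fn()                # dummy input batch (assumed to be a dict of tensors)
outputs = output_transform_fn(model(**data))
criterion = lambda x: x.mean()      # same criterion as in the hunk above
loss = criterion(next(iter(outputs.values())))  # assumes a dict-like output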
@@ -40,7 +40,7 @@ else:
 
 @clear_cache_before_run()
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("test_config", TEST_CONFIGS)
 def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
@@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo
 
 
 @clear_cache_before_run()
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("plugin_type", ["ddp", "zero", "gemini"])
 def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
     (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
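Every hunk in this commit makes the same substitution: the test model moves from "transformers_gpt" to "transformers_llama_for_casual_lm". Note that @parameterize here comes from colossalai.testing and decorates plain exam_* functions that are presumably invoked from run_dist inside the spawned workers rather than being collected by pytest. A hypothetical re-implementation conveying that apparent behavior (an illustration, not the library's code) is:

import functools


def parameterize_sketch(name, values):
    """Hypothetical stand-in for colossalai.testing.parameterize.

    Calling the decorated function once runs its body for every value, so
    stacked decorators expand into the full cross product of arguments.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, name: value})

        return wrapper

    return decorator


@parameterize_sketch("shard", [True, False])
@parameterize_sketch("model_name", ["transformers_llama_for_casual_lm"])
def exam_demo(shard: bool, model_name: str):
    print(shard, model_name)  # runs once per (shard, model_name) combination


exam_demo()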