mirror of https://github.com/hpcaitech/ColossalAI.git
[workflow] fixed build CI (#5240)
* [workflow] fixed build CI
* polish
* polish
* polish
* polish
* polish
@@ -7,6 +7,7 @@ from transformers import LlamaForCausalLM
 from utils import shared_tempdir
 
 import colossalai
+from colossalai.testing import skip_if_not_enough_gpus
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.lazy import LazyInitContext
@@ -68,7 +69,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 @clear_cache_before_run()
 @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("tp_size", [1, 2])
 @parameterize("zero_size", [2])
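The hunk above swaps the checkpoint-IO test's model entry from "transformers_gpt" to "transformers_llama_for_casual_lm" while keeping the stacked @parameterize sweep. As a rough illustration only, the sketch below shows how stacked decorators of this kind could expand into a test matrix, assuming colossalai.testing.parameterize simply re-invokes the wrapped function once per listed value; the function name exam_sweep is hypothetical, and only the decorator's call sites, not its implementation, appear in this commit.

from functools import wraps

def parameterize(arg_name, values):
    # Simplified stand-in (assumption): re-invoke the wrapped function once
    # per listed value of arg_name.
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator

@parameterize("model_name", ["transformers_llama_for_casual_lm"])
@parameterize("size_per_shard", [32])
def exam_sweep(model_name, size_per_shard):
    print(model_name, size_per_shard)  # one call per combination

exam_sweep()  # sweeps the 1 x 1 matrix defined above

Under this reading, calling the decorated function with no arguments walks every combination in the sweep.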
@@ -156,13 +157,12 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_gemini_ckpIO(world_size):
-    spawn(run_dist, world_size)
+def test_gemini_ckpIO():
+    spawn(run_dist, 4)
 
 @pytest.mark.largedist
-@pytest.mark.parametrize("world_size", [8])
+@skip_if_not_enough_gpus(min_gpus=8)
 @rerun_if_address_is_in_use()
-def test_gemini_ckpIO_3d(world_size):
-    spawn(run_dist, world_size)
+def test_gemini_ckpIO_3d():
+    spawn(run_dist, 8)
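The reworked entry points drop pytest's world_size parametrization, hard-code the worker count passed to spawn (4 and 8), and gate the 8-GPU variant behind the largedist marker plus @skip_if_not_enough_gpus(min_gpus=8), so CI runners with fewer devices skip it instead of failing. The commit only shows the decorator's call site, not its implementation; below is a minimal sketch of how such a guard might look, assuming it compares torch.cuda.device_count() against min_gpus.

import functools

import pytest
import torch

def skip_if_not_enough_gpus(min_gpus: int):
    # Assumption: the real colossalai.testing helper skips the test when the
    # visible CUDA device count is below min_gpus; only its usage appears here.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if torch.cuda.device_count() < min_gpus:
                pytest.skip(f"requires at least {min_gpus} GPUs")
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@pytest.mark.largedist
@skip_if_not_enough_gpus(min_gpus=8)
def test_gemini_ckpIO_3d():
    ...  # the real test calls spawn(run_dist, 8)

A pytest.mark.skipif(torch.cuda.device_count() < 8, reason="needs 8 GPUs") marker would achieve a similar effect at collection time; the decorator form sketched above defers the check to call time.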