mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-05-11 05:05:35 +00:00
[workflow] fixed build CI (#5240)
* [workflow] fixed build CI * polish * polish * polish * polish * polish
This commit is contained in:
@@ -13,7 +13,7 @@ from colossalai.lazy.lazy_init import LazyInitContext
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.kit.model_zoo import model_zoo, COMMON_MODELS, IS_FAST_TEST
|
||||
|
||||
|
||||
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
|
||||
@@ -66,7 +66,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
|
||||
# @parameterize('init_method', ['lazy', 'none', 'colo'])
|
||||
|
||||
|
||||
@parameterize("subset", ["torchvision", "transformers", "diffusers"])
|
||||
@parameterize("subset", [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "transformers", "diffusers"])
|
||||
@parameterize("init_method", ["none"])
|
||||
@parameterize("zero_size", [2])
|
||||
@parameterize("tp_size", [2])
|
||||
|
||||
@@ -11,7 +11,7 @@ from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
|
||||
# from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
|
||||
|
||||
# These models are not compatible with AMP
|
||||
_AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"]
|
||||
@@ -62,7 +62,12 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
|
||||
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
|
||||
skipped_models = []
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
|
||||
if IS_FAST_TEST:
|
||||
registry = model_zoo.get_sub_registry(COMMON_MODELS)
|
||||
else:
|
||||
registry = model_zoo
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
|
||||
# FIXME(ver217): fix these models
|
||||
if name in ignore_models:
|
||||
skipped_models.append(name)
|
||||
|
||||
@@ -11,7 +11,7 @@ from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import TorchDDPPlugin
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
|
||||
|
||||
|
||||
def run_fn(model_fn, data_gen_fn, output_transform_fn):
|
||||
@@ -40,7 +40,12 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
|
||||
|
||||
|
||||
def check_torch_ddp_plugin():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
|
||||
if IS_FAST_TEST:
|
||||
registry = model_zoo.get_sub_registry(COMMON_MODELS)
|
||||
else:
|
||||
registry = model_zoo
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
|
||||
if name == "dlrm_interactionarch":
|
||||
continue
|
||||
run_fn(model_fn, data_gen_fn, output_transform_fn)
|
||||
|
||||
@@ -12,7 +12,7 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"):
|
||||
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.kit.model_zoo import model_zoo, IS_FAST_TEST, COMMON_MODELS
|
||||
|
||||
|
||||
# test basic fsdp function
|
||||
@@ -42,7 +42,12 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
|
||||
|
||||
|
||||
def check_torch_fsdp_plugin():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
|
||||
if IS_FAST_TEST:
|
||||
registry = model_zoo.get_sub_registry(COMMON_MODELS)
|
||||
else:
|
||||
registry = model_zoo
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
|
||||
if any(
|
||||
element in name
|
||||
for element in [
|
||||
|
||||
Reference in New Issue
Block a user