mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[workflow] fixed build CI (#5240)
* [workflow] fixed build CI * polish * polish * polish * polish * polish
This commit is contained in:
@@ -1,5 +1,33 @@
|
||||
from . import custom, diffusers, timm, torchaudio, torchrec, torchvision, transformers
|
||||
import os
|
||||
from . import custom, diffusers, timm, torchaudio, torchvision, transformers
|
||||
from .executor import run_fwd, run_fwd_bwd
|
||||
from .registry import model_zoo
|
||||
|
||||
__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd"]
|
||||
# We pick a subset of models for fast testing in order to reduce the total testing time
|
||||
COMMON_MODELS = [
|
||||
'custom_hanging_param_model',
|
||||
'custom_nested_model',
|
||||
'custom_repeated_computed_layers',
|
||||
'custom_simple_net',
|
||||
'diffusers_clip_text_model',
|
||||
'diffusers_auto_encoder_kl',
|
||||
'diffusers_unet2d_model',
|
||||
'timm_densenet',
|
||||
'timm_resnet',
|
||||
'timm_swin_transformer',
|
||||
'torchaudio_wav2vec2_base',
|
||||
'torchaudio_conformer',
|
||||
'transformers_bert_for_masked_lm',
|
||||
'transformers_bloom_for_causal_lm',
|
||||
'transformers_falcon_for_causal_lm',
|
||||
'transformers_chatglm_for_conditional_generation',
|
||||
'transformers_llama_for_casual_lm',
|
||||
'transformers_vit_for_masked_image_modeling',
|
||||
'transformers_mistral_for_casual_lm'
|
||||
]
|
||||
|
||||
IS_FAST_TEST = os.environ.get('FAST_TEST', '0') == '1'
|
||||
|
||||
|
||||
__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", 'COMMON_MODELS', 'IS_FAST_TEST']
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
from typing import Callable, List, Union
|
||||
|
||||
__all__ = ["ModelZooRegistry", "ModelAttribute", "model_zoo"]
|
||||
|
||||
@@ -61,7 +61,7 @@ class ModelZooRegistry(dict):
|
||||
"""
|
||||
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
|
||||
|
||||
def get_sub_registry(self, keyword: str):
|
||||
def get_sub_registry(self, keyword: Union[str, List[str]]):
|
||||
"""
|
||||
Get a sub registry with models that contain the keyword.
|
||||
|
||||
@@ -70,12 +70,15 @@ class ModelZooRegistry(dict):
|
||||
"""
|
||||
new_dict = dict()
|
||||
|
||||
if isinstance(keyword, str):
|
||||
keyword_list = [keyword]
|
||||
else:
|
||||
keyword_list = keyword
|
||||
assert isinstance(keyword_list, (list, tuple))
|
||||
|
||||
for k, v in self.items():
|
||||
if keyword == "transformers_gpt":
|
||||
if keyword in k and not "gptj" in k: # ensure GPT2 does not retrieve GPTJ models
|
||||
new_dict[k] = v
|
||||
else:
|
||||
if keyword in k:
|
||||
for kw in keyword_list:
|
||||
if kw in k:
|
||||
new_dict[k] = v
|
||||
|
||||
assert len(new_dict) > 0, f"No model found with keyword {keyword}"
|
||||
|
||||
Reference in New Issue
Block a user