mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
アマデウス
parent
4c69e2dc91
commit
df612434c9
@@ -1,33 +1,33 @@
|
||||
import os
|
||||
|
||||
from . import custom, diffusers, timm, torchaudio, torchvision, transformers
|
||||
from .executor import run_fwd, run_fwd_bwd
|
||||
from .registry import model_zoo
|
||||
|
||||
# We pick a subset of models for fast testing in order to reduce the total testing time
|
||||
COMMON_MODELS = [
|
||||
'custom_hanging_param_model',
|
||||
'custom_nested_model',
|
||||
'custom_repeated_computed_layers',
|
||||
'custom_simple_net',
|
||||
'diffusers_clip_text_model',
|
||||
'diffusers_auto_encoder_kl',
|
||||
'diffusers_unet2d_model',
|
||||
'timm_densenet',
|
||||
'timm_resnet',
|
||||
'timm_swin_transformer',
|
||||
'torchaudio_wav2vec2_base',
|
||||
'torchaudio_conformer',
|
||||
'transformers_bert_for_masked_lm',
|
||||
'transformers_bloom_for_causal_lm',
|
||||
'transformers_falcon_for_causal_lm',
|
||||
'transformers_chatglm_for_conditional_generation',
|
||||
'transformers_llama_for_casual_lm',
|
||||
'transformers_vit_for_masked_image_modeling',
|
||||
'transformers_mistral_for_casual_lm'
|
||||
"custom_hanging_param_model",
|
||||
"custom_nested_model",
|
||||
"custom_repeated_computed_layers",
|
||||
"custom_simple_net",
|
||||
"diffusers_clip_text_model",
|
||||
"diffusers_auto_encoder_kl",
|
||||
"diffusers_unet2d_model",
|
||||
"timm_densenet",
|
||||
"timm_resnet",
|
||||
"timm_swin_transformer",
|
||||
"torchaudio_wav2vec2_base",
|
||||
"torchaudio_conformer",
|
||||
"transformers_bert_for_masked_lm",
|
||||
"transformers_bloom_for_causal_lm",
|
||||
"transformers_falcon_for_causal_lm",
|
||||
"transformers_chatglm_for_conditional_generation",
|
||||
"transformers_llama_for_casual_lm",
|
||||
"transformers_vit_for_masked_image_modeling",
|
||||
"transformers_mistral_for_casual_lm",
|
||||
]
|
||||
|
||||
IS_FAST_TEST = os.environ.get('FAST_TEST', '0') == '1'
|
||||
IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"
|
||||
|
||||
|
||||
__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", 'COMMON_MODELS', 'IS_FAST_TEST']
|
||||
|
||||
__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", "COMMON_MODELS", "IS_FAST_TEST"]
|
||||
|
@@ -102,4 +102,4 @@ class ModelZooRegistry(dict):
|
||||
return new_dict
|
||||
|
||||
|
||||
model_zoo = ModelZooRegistry()
|
||||
model_zoo = ModelZooRegistry()
|
||||
|
@@ -2,6 +2,7 @@ import torch
|
||||
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
# ================================
|
||||
|
Reference in New Issue
Block a user