ColossalAI/tests/kit/model_zoo/__init__.py
Edenzzzz f5c84af0b0
[Feature] Zigzag Ring attention (#5905)
* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-08-16 13:56:38 +08:00


import os

from . import custom, diffusers, timm, torchaudio, torchvision, transformers
from .executor import run_fwd, run_fwd_bwd
from .registry import model_zoo

# We pick a subset of models for fast testing in order to reduce the total testing time
COMMON_MODELS = [
    "custom_hanging_param_model",
    "custom_nested_model",
    "custom_repeated_computed_layers",
    "custom_simple_net",
    "diffusers_clip_text_model",
    "diffusers_auto_encoder_kl",
    "diffusers_unet2d_model",
    "timm_densenet",
    "timm_resnet",
    "timm_swin_transformer",
    "torchaudio_wav2vec2_base",
    "torchaudio_conformer",
    "transformers_bert_for_masked_lm",
    "transformers_bloom_for_causal_lm",
    "transformers_falcon_for_causal_lm",
    "transformers_chatglm_for_conditional_generation",
    "transformers_llama_for_causal_lm",
    "transformers_vit_for_masked_image_modeling",
    "transformers_mistral_for_causal_lm",
]

IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"


__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", "COMMON_MODELS", "IS_FAST_TEST"]
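A test module would typically consult IS_FAST_TEST to decide whether to run over every registered model or only the COMMON_MODELS subset. The sketch below illustrates that selection logic under stated assumptions: the registry is modeled as a plain dict of hypothetical name-to-builder callables rather than the real model_zoo object, and the helper names is_fast_test and select_models are illustrative only, not part of ColossalAI's API.

```python
import os
from typing import Callable, Dict, List, Mapping

# Hypothetical stand-in for a registry entry: a zero-argument builder.
# The real model_zoo registry has its own API; this is only a sketch.
ModelBuilder = Callable[[], str]


def is_fast_test(env: Mapping[str, str]) -> bool:
    # Mirrors IS_FAST_TEST above: only the exact string "1" enables fast mode.
    return env.get("FAST_TEST", "0") == "1"


def select_models(
    registry: Dict[str, ModelBuilder], common: List[str], fast: bool
) -> Dict[str, ModelBuilder]:
    """Return the full registry, or only the `common` subset in fast mode."""
    if not fast:
        return dict(registry)
    # Fail loudly if the fast-test subset names a model that is not registered.
    missing = [name for name in common if name not in registry]
    if missing:
        raise KeyError(f"unknown models in fast-test subset: {missing}")
    # Preserve the order in which the subset is listed.
    return {name: registry[name] for name in common}


if __name__ == "__main__":
    # Hypothetical registry contents, for illustration only.
    registry = {
        "custom_simple_net": lambda: "simple_net",
        "timm_resnet": lambda: "resnet",
        "transformers_llama_for_causal_lm": lambda: "llama",
    }
    common = ["custom_simple_net", "timm_resnet"]
    for name in select_models(registry, common, is_fast_test(os.environ)):
        print(name)
```

Because the check is an exact string comparison, FAST_TEST=true or FAST_TEST=yes leaves fast mode off; only FAST_TEST=1 enables it.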