mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-05-05 12:24:38 +00:00
[zero] refactor low level zero for shard evenly (#4030)
* refactor low level zero * fix zero2 and support cpu offload * avg gradient and modify unit test * refactor grad store, support layer drop * refactor bucket store, support grad accumulation * fix and update unit test of zero and ddp * compatible with tp, ga and unit test * fix memory leak and polish * add zero layer drop unittest * polish code * fix import err in unit test * support diffenert comm dtype, modify docstring style * polish code * test padding and fix * fix unit test of low level zero * fix pad recording in bucket store * support some models * polish
This commit is contained in:
@@ -11,14 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
# These models are not compatible with AMP
|
||||
_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
|
||||
_AMP_ERR_MODELS = ['timm_convit', 'deepfm_interactionarch']
|
||||
# These models have no parameters
|
||||
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
|
||||
# These models will get stuck
|
||||
_STUCK_MODELS = [
|
||||
'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
|
||||
'transformers_bert_for_pretraining', 'transformers_gpt_double_heads'
|
||||
]
|
||||
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']
|
||||
|
||||
|
||||
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
|
||||
@@ -58,7 +53,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
|
||||
"""
|
||||
passed_models = []
|
||||
failed_info = {} # (model_name, error) pair
|
||||
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
|
||||
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS
|
||||
skipped_models = []
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
|
||||
|
||||
Reference in New Issue
Block a user