ColossalAI/tests/test_booster/test_plugin/test_gemini_plugin.py

from contextlib import nullcontext
from typing import Optional

import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import (
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    skip_if_not_enough_gpus,
    spawn,
)
from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo
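

# This file sweeps the model zoo through GeminiPlugin (ZeRO-style parameter
# sharding with heterogeneous memory management, optionally combined with
# tensor parallelism) and checks that each model can be boosted, run a full
# forward/backward pass, and step its optimizer without error.
#
# The booster pattern exercised below, shown standalone as a minimal sketch
# (`MyModel`, `loss_fn`, and `batch` are illustrative placeholders, not part
# of this test):
#
#   model = MyModel()
#   optimizer = HybridAdam(model.parameters(), lr=1e-3)
#   booster = Booster(plugin=GeminiPlugin(max_norm=1.0, initial_scale=2**5))
#   model, optimizer, loss_fn, _, _ = booster.boost(model, optimizer, loss_fn)
#   loss = loss_fn(model(**batch))
#   booster.backward(loss, optimizer)  # plugin-aware replacement for loss.backward()
#   optimizer.step()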
@clear_cache_before_run()
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
    try:
        if init_method == "lazy":
            ctx = LazyInitContext()
        else:
            ctx = nullcontext()
        # The parallel axes must factor the world size:
        # world_size == zero_size * tp_size * extra_dp_size.
        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
        enable_all_optimization = tp_size > 1
        plugin = GeminiPlugin(
            max_norm=1.0,
            initial_scale=2**5,
            tp_size=tp_size,
            extra_dp_size=extra_dp_size,
            enable_all_optimization=enable_all_optimization,
        )
        booster = Booster(plugin=plugin)
        with ctx:
            model = model_fn()
        optimizer = HybridAdam(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()
        data = data_gen_fn()

        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }

        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

        for n, p in model.named_parameters():
            assert isinstance(p, ColoParameter), f"{n} is not a ColoParameter"

        output = model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])

        booster.backward(loss, optimizer)
        optimizer.step()

    except NotImplementedError:
        print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet.\n")
    except Exception as e:
        # raise e
        return repr(e)
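

# With the sweep values below (zero_size=2, tp_size=2), an 8-rank run yields
# extra_dp_size = 8 // (2 * 2) = 2, i.e. a 2 x 2 x 2 ZeRO/TP/DP layout, while a
# 4-rank run yields extra_dp_size = 1, disabling the extra data-parallel axis.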
# TODO(ver217): CI does not support lazy now
# @parameterize('init_method', ['lazy', 'none', 'colo'])
@parameterize("subset", [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(
    subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1
):
    """Check the Gemini plugin over the model zoo.

    Args:
        early_stop (bool, optional): Whether to stop on the first error. Defaults to True.
    """
    is_support_meta = is_compatible_with_meta()
    if not is_support_meta and init_method == "lazy":
        return

    passed_models = []
    failed_info = {}  # (model_name, error) pair

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
        # These models lead to CUDA errors
        if name in (
            "diffusers_auto_encoder_kl",
            "diffusers_vq_model",
            "diffusers_unet2d_model",
            "timm_resmlp",
            "timm_gmixer_12_224",
            "timm_gmlp_b16_224",
            "timm_mixer_b16_224",
            "timm_convnext",
            "torchvision_convnext_base",
        ):
            continue
        # These models are not compatible with Gemini
        if name in [
            "timm_convit",
            "timm_dm_nfnet",
            "torchvision_vit_b_16",
            "transformers_t5",
            "transformers_t5_for_conditional_generation",
            "transformers_t5_encoder_model",  # does not support apex rmsnorm
            "transformers_chatglm",
            "transformers_sam",
            "transformers_vit",
            "transformers_gpt_double_heads",  # TODO: check why the model fails to run with Gemini
            "transformers_falcon",  # TODO: check why falcon fails to run with Gemini
            "transformers_falcon_for_causal_lm",
            "transformers_falcon_for_sequence_classification",
            "transformers_falcon_for_token_classification",
            "transformers_falcon_for_question_answering",
            "transformers_gptj_lm",  # leads to OOM when running in CI
            "transformers_gptj_for_question_answering",
            "transformers_gptj_for_sequence_classification",
        ]:
            continue

        if init_method == "lazy" and name in [
            "timm_convmixer",
            "timm_vision_transformer",
            "timm_deit",
            "timm_deit3",
            "timm_inception_v3",
            "timm_tnt_b_patch16_224",
            "timm_rexnet",
            "torchvision_densenet121",
            "torchvision_efficientnet_b0",
            "torchvision_mobilenet_v2",
            "torchvision_mnasnet0_5",
            "torchvision_regnet_x_16gf",
            "torchvision_shufflenet_v2_x0_5",
            "torchvision_efficientnet_v2_s",
        ]:
            continue

        # TODO: debug blip2 when using tp; something is wrong with shift_logits' shape
        if "transformers_blip2" in name:
            tp_size = 1

        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
        if err is None:
            passed_models.append(name)
        else:
            failed_info[name] = err
            if early_stop:
                break

    if dist.get_rank() == 0:
        print(f"Init method: {init_method}")
        print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
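

# Every rank evaluates the assert above, so a failure aborts all workers with
# the collected (model name, error) pairs; only rank 0 prints the summary.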
def run_dist(rank, world_size, port, early_stop: bool = True):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
    check_gemini_plugin(early_stop=early_stop)
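

# spawn(run_dist, N) launches N worker processes, handing each its rank plus
# the world size and a shared free port, which colossalai.launch uses to
# initialize the default process group before the checks run.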
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
    spawn(run_dist, 4, early_stop=early_stop)


@pytest.mark.largedist
@skip_if_not_enough_gpus(8)
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
    spawn(run_dist, 8, early_stop=early_stop)


if __name__ == "__main__":
    test_gemini_plugin(early_stop=False)