Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 06:30:41 +00:00)
[gemini] gemini support extra-dp (#5043)
* support ddp
* fix
* fix
* fix fix
* support ddp
* fix
* fix
* fix fix
* simplify tests
* fix
* fix
* fix fix fix
* fix
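The change threads an extra data-parallel dimension through Gemini: instead of toggling a boolean enable_tensor_parallelism, the tests below take tp_size and zero_size and derive extra_dp_size from the world size. A minimal sketch of that usage pattern, assuming an 8-rank launch, a toy model, and a launch_from_torch call, none of which appear in this diff:

import colossalai
import torch.distributed as dist
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Assumption: 8 ranks, e.g. started with `torchrun --nproc_per_node=8`.
colossalai.launch_from_torch(config={})

tp_size = 2    # tensor-parallel degree
zero_size = 2  # intended ZeRO (Gemini) shard degree
# Same arithmetic as the updated tests: leftover ranks form the extra DP group.
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)  # 8 // (2 * 2) == 2

plugin = GeminiPlugin(
    max_norm=1.0,
    initial_scale=2**5,
    tp_size=tp_size,
    extra_dp_size=extra_dp_size,
    enable_all_optimization=tp_size > 1,
)
booster = Booster(plugin=plugin)

model = nn.Linear(32, 32)         # stand-in for a model-zoo model
model, *_ = booster.boost(model)  # optimizer omitted, as in the checkpoint test below

On 8 ranks this reproduces the 2 (ZeRO) x 2 (TP) x 2 (extra DP) layout that the new *_3d tests exercise.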
@@ -1,5 +1,6 @@
 from contextlib import nullcontext
 from typing import Optional

 import pytest
 import torch
+import torch.distributed as dist
@@ -17,14 +18,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo


-def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
+def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
     try:
         if init_method == "lazy":
             ctx = LazyInitContext()
         else:
             ctx = nullcontext()
-        enable_all_optimization = True if enable_tensor_parallelism else False
-        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization)
+        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+        enable_all_optimization = True if tp_size > 1 else False
+        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
         booster = Booster(plugin=plugin)
         with ctx:
             model = model_fn()
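For the new 8-rank test_gemini_plugin_3d entry point added further down, the division above works out as follows (plain Python, no distributed setup needed):

# zero_size=2 and tp_size=2 are the values parameterized below; on 8 ranks
# the leftover factor becomes the extra data-parallel degree.
world_size, tp_size, zero_size = 8, 2, 2
extra_dp_size = world_size // (zero_size * tp_size)
assert extra_dp_size == 2  # and 4 // (2 * 2) == 1 for the original 4-GPU test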
@@ -62,8 +64,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tenso

 @parameterize("subset", ["torchvision", "transformers", "diffusers"])
 @parameterize("init_method", ["none"])
-@parameterize("enable_tensor_parallelism", [True, False])
-def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
+@parameterize("zero_size", [2])
+@parameterize("tp_size", [2])
+def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
     """check gemini plugin over model zoo

     Args:
@@ -125,9 +128,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_pa

         # TODO debug blip2 when using tp, something wrong with shift_logits's shape
         if "transformers_blip2" in name:
-            enable_tensor_parallelism = False
+            tp_size = 1

-        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism)
+        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
         torch.cuda.empty_cache()
         if err is None:
             passed_models.append(name)
@@ -153,6 +156,11 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 def test_gemini_plugin(early_stop: bool = True):
     spawn(run_dist, 4, early_stop=early_stop)

+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+def test_gemini_plugin_3d(early_stop: bool = True):
+    spawn(run_dist, 8, early_stop=early_stop)
+

 if __name__ == "__main__":
     test_gemini_plugin(early_stop=False)
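The 3d variant simply spawns twice as many workers over the same run_dist helper. A trimmed sketch of that spawn pattern; the body of run_dist here is an assumption (the real one goes on to call check_gemini_plugin), the function name demo_gemini_3d_spawn is hypothetical, and only the spawn call and signatures match the diff:

import colossalai
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port, early_stop: bool = True):
    # Illustrative body: bring up the process group, then check the layout.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port)
    assert dist.get_world_size() // (2 * 2) == 2  # zero_size=2, tp_size=2 -> extra_dp_size=2


@rerun_if_address_is_in_use()
def demo_gemini_3d_spawn(early_stop: bool = True):  # mirrors the new test_gemini_plugin_3d
    spawn(run_dist, 8, early_stop=early_stop)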
@@ -37,20 +37,21 @@ OPTIM_PLACEMENT_CONFIGS = [
 @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_bert_for_sequence_classification"])
 @parameterize("use_safetensors", [False, True])
-@parameterize("enable_tensor_parallelism", [True, False])
-@parameterize("tp_size", [2])
-def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int):
+@parameterize("tp_size", [1, 2])
+@parameterize("zero_size", [2])
+def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
     from transformers import BertForSequenceClassification

     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     bert_model = model_fn()
-    enable_all_optimization = True if enable_tensor_parallelism else False
+    enable_all_optimization = True if tp_size > 1 else False

     with shared_tempdir() as tempdir:
         pretrained_path = os.path.join(tempdir, "pretrained")
         bert_model.config.save_pretrained(save_directory=pretrained_path)

-        plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
+        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+        plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
         booster = Booster(plugin=plugin)
         bert_model, _, _, _, _ = booster.boost(bert_model)
         model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -69,13 +70,14 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 @parameterize("shard", [True, False])
 @parameterize("model_name", ["transformers_gpt"])
 @parameterize("size_per_shard", [32])
-@parameterize("enable_tensor_parallelism", [True, False])
-@parameterize("tp_size", [2])
-def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int):
+@parameterize("tp_size", [1, 2])
+@parameterize("zero_size", [2])
+def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    enable_all_optimization = True if enable_tensor_parallelism else False
-    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
+    enable_all_optimization = True if tp_size > 1 else False
+    extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
     booster = Booster(plugin=plugin)

     model = model_fn()
@@ -158,3 +160,9 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)
+
+@pytest.mark.largedist
+@pytest.mark.parametrize("world_size", [8])
+@rerun_if_address_is_in_use()
+def test_gemini_ckpIO_3d(world_size):
+    spawn(run_dist, world_size)
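The checkpoint-IO tests keep the same save/load round trip and only change how the plugin is built. A condensed sketch of that round trip under the extra-DP plugin; the toy model, the checkpoint path, and the assumption that ranks are already launched are illustrative, while the keyword names (shard, size_per_shard, use_safetensors) follow the parameters exercised by exam_state_dict:

import torch.distributed as dist
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Assumption: colossalai has already been launched on 8 ranks,
# mirroring the new test_gemini_ckpIO_3d entry point.
tp_size, zero_size = 2, 2
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)

plugin = GeminiPlugin(
    precision="fp16",
    initial_scale=2**14,
    tp_size=tp_size,
    extra_dp_size=extra_dp_size,
    enable_all_optimization=tp_size > 1,
)
booster = Booster(plugin=plugin)

model = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 32))  # toy stand-in
model, *_ = booster.boost(model)

# Sharded save/load round trip in the spirit of exam_state_dict.
booster.save_model(model, "gemini_ckpt", shard=True, size_per_shard=32, use_safetensors=False)
booster.load_model(model, "gemini_ckpt")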
@@ -124,25 +124,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,
@@ -153,23 +134,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "max_norm": 5,
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,
@@ -102,28 +102,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp32",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "fp32",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,
@@ -148,7 +131,7 @@ def run_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,
@@ -106,17 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 4,
             "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "zero_stage": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
@@ -126,36 +116,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 2,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 2,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "bf16",
             "max_norm": 5,
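The hunks above only prune and retune the test_config parameterizations in the shardformer tests: redundant tp/pp variants are dropped and enable_all_optimization is switched off for the heavier fp32 and ZeRO combinations. For orientation, a hedged sketch of how such a config dict is typically consumed; the HybridParallelPlugin name and the popping of use_lazy_init are assumptions about the surrounding test harness, not something this diff shows:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# One of the surviving configurations (assumes a launched 4-rank run: tp_size=2 x pp_size=2).
test_config = {
    "tp_size": 2,
    "pp_size": 2,
    "num_microbatches": 4,
    "enable_all_optimization": False,  # the value these hunks flip to for the fp32 runs
    "use_lazy_init": False,
    "precision": "fp32",
    "max_norm": 5,
}

# Assumed harness behavior: lazy init is handled outside the plugin constructor.
use_lazy_init = test_config.pop("use_lazy_init", False)
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)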
@@ -39,7 +39,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     pg = _get_default_group()
     my_chunk = Chunk(
         chunk_size=1024,
-        process_group=pg,
+        zero_group=pg,
         dtype=torch.float32,
         init_device=init_device,
         cpu_shard_init=True,
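The final hunk tracks an argument rename in the Chunk constructor: the sharding group the chunk test used to pass as process_group is now passed as zero_group. A minimal construction mirroring the updated call; it assumes an already-initialized default process group, and the keep_gathered/pin_memory flags echo the test's other parameters (the hunk itself is truncated before them):

import torch
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.zero.gemini.chunk import Chunk

# Assumption: torch.distributed is already initialized (e.g. inside a spawned test worker).
pg = _get_default_group()
my_chunk = Chunk(
    chunk_size=1024,
    zero_group=pg,      # formerly `process_group`: the group the chunk is sharded over
    dtype=torch.float32,
    init_device=None,   # let the chunk fall back to the current device
    cpu_shard_init=True,
    keep_gathered=False,
    pin_memory=False,
)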