[gemini] gemini support extra-dp (#5043)

* support ddp

* fix

* fix

* fix

fix

* support ddp

* fix

* fix

* fix

fix

* simplify tests

* fix

* fix

* fix

fix

fix

* fix
This commit is contained in:
flybird11111
2023-11-16 21:03:04 +08:00
committed by GitHub
parent b2ad0d9e8f
commit 3e02154710
10 changed files with 96 additions and 137 deletions

View File

@@ -1,5 +1,6 @@
from contextlib import nullcontext
from typing import Optional
import pytest
import torch
import torch.distributed as dist
@@ -17,14 +18,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
try:
if init_method == "lazy":
ctx = LazyInitContext()
else:
ctx = nullcontext()
enable_all_optimization = True if enable_tensor_parallelism else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization)
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
enable_all_optimization = True if tp_size > 1 else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
@@ -62,8 +64,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tenso
@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("enable_tensor_parallelism", [True, False])
def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
"""check gemini plugin over model zoo
Args:
@@ -125,9 +128,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_pa
# TODO debug blip2 when using tp, something wrong with shift_logits's shape
if "transformers_blip2" in name:
enable_tensor_parallelism = False
tp_size = 1
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism)
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
@@ -153,6 +156,11 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
def test_gemini_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
spawn(run_dist, 8, early_stop=early_stop)
if __name__ == "__main__":
test_gemini_plugin(early_stop=False)
test_gemini_plugin(early_stop=False)

View File

@@ -37,20 +37,21 @@ OPTIM_PLACEMENT_CONFIGS = [
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
@parameterize("enable_tensor_parallelism", [True, False])
@parameterize("tp_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int):
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
from transformers import BertForSequenceClassification
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
bert_model = model_fn()
enable_all_optimization = True if enable_tensor_parallelism else False
enable_all_optimization = True if tp_size > 1 else False
with shared_tempdir() as tempdir:
pretrained_path = os.path.join(tempdir, "pretrained")
bert_model.config.save_pretrained(save_directory=pretrained_path)
plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
booster = Booster(plugin=plugin)
bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -69,13 +70,14 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_gpt"])
@parameterize("size_per_shard", [32])
@parameterize("enable_tensor_parallelism", [True, False])
@parameterize("tp_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int):
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
enable_all_optimization = True if enable_tensor_parallelism else False
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
enable_all_optimization = True if tp_size > 1 else False
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
booster = Booster(plugin=plugin)
model = model_fn()
@@ -158,3 +160,9 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO_3d(world_size):
spawn(run_dist, world_size)

View File

@@ -124,25 +124,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
@@ -153,23 +134,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 2,

View File

@@ -102,28 +102,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp32",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,
@@ -148,7 +131,7 @@ def run_test(test_config):
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,

View File

@@ -106,17 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 1,
"zero_stage": 1,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
@@ -126,36 +116,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 1,
"zero_stage": 2,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"zero_stage": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"zero_stage": 2,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,

View File

@@ -39,7 +39,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
pg = _get_default_group()
my_chunk = Chunk(
chunk_size=1024,
process_group=pg,
zero_group=pg,
dtype=torch.float32,
init_device=init_device,
cpu_shard_init=True,