[gemini] gemini support extra-dp (#5043)

* support ddp

* fix

* fix

* fix

fix

* support ddp

* fix

* fix

* fix

fix

* simplify tests

* fix

* fix

* fix

fix

fix

* fix
This commit is contained in:
flybird11111
2023-11-16 21:03:04 +08:00
committed by GitHub
parent b2ad0d9e8f
commit 3e02154710
10 changed files with 96 additions and 137 deletions

View File

@@ -37,20 +37,21 @@ OPTIM_PLACEMENT_CONFIGS = [
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
@parameterize("enable_tensor_parallelism", [True, False])
@parameterize("tp_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int):
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
from transformers import BertForSequenceClassification
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
bert_model = model_fn()
enable_all_optimization = True if enable_tensor_parallelism else False
enable_all_optimization = True if tp_size > 1 else False
with shared_tempdir() as tempdir:
pretrained_path = os.path.join(tempdir, "pretrained")
bert_model.config.save_pretrained(save_directory=pretrained_path)
plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
booster = Booster(plugin=plugin)
bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -69,13 +70,14 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_gpt"])
@parameterize("size_per_shard", [32])
@parameterize("enable_tensor_parallelism", [True, False])
@parameterize("tp_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int):
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
enable_all_optimization = True if enable_tensor_parallelism else False
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
enable_all_optimization = True if tp_size > 1 else False
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
booster = Booster(plugin=plugin)
model = model_fn()
@@ -158,3 +160,9 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO_3d(world_size):
spawn(run_dist, world_size)