[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)

* sharded optimizer checkpoint for gemini plugin

* modify test to reduce testing time

* update doc

* fix bug when keep_gathered is true under GeminiPlugin

Author: Baizhou Zhang
Date: 2023-07-21 14:39:01 +08:00
Committed by: GitHub
Parent: fc5cef2c79
Commit: c6f6005990

12 changed files with 289 additions and 84 deletions
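
For context, the sketch below shows how the new sharded optimizer checkpointing can be exercised through the Booster API. It is a minimal illustration, not code from this commit: the toy model, the 'optim_ckpt' directory name, and the hyperparameters are assumptions, and exact signatures may differ across ColossalAI versions.

import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumes a distributed launch, e.g. via torchrun with 2 processes.
colossalai.launch_from_torch(config={})

model = nn.Linear(16, 16).cuda()  # toy model, for illustration only
optimizer = HybridAdam(model.parameters(), lr=1e-3)

booster = Booster(plugin=GeminiPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# ... run a few training steps so the optimizer holds state ...

# Save the optimizer state as multiple shard files plus an index file;
# 'optim_ckpt' is a hypothetical directory, size_per_shard is in MB.
booster.save_optimizer(optimizer, 'optim_ckpt', shard=True, size_per_shard=32)

# Restore from the sharded checkpoint (depending on version, the expected
# path may be the directory or the generated index file).
booster.load_optimizer(optimizer, 'optim_ckpt')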

@@ -19,7 +19,7 @@ from tests.kit.model_zoo import model_zoo
 @clear_cache_before_run()
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
 @parameterize('model_name', ['transformers_gpt'])
 def exam_torch_load_from_gemini(shard: bool, model_name: str):
@@ -83,7 +83,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
 @clear_cache_before_run()
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
 @parameterize('model_name', ['transformers_gpt'])
 def exam_gemini_load_from_torch(shard: bool, model_name: str):
@@ -165,7 +165,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('world_size', [2])
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)
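
With the stacked @parameterize decorators above, each exam_* function now runs for both shard=False and shard=True against the single 'transformers_gpt' model, so the sharded path added by this commit is covered. Narrowing the pytest-level world_size parametrization from [1, 2] to [2] keeps only the two-process run, which is the one that actually spreads optimizer shards across ranks, and is what reduces the overall testing time.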