[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)

* sharded optimizer checkpoint for gemini plugin

* modify test to reduce testing time

* update doc

* fix bug when keep_gathered is true under GeminiPlugin
Baizhou Zhang
2023-07-21 14:39:01 +08:00
committed by GitHub
parent fc5cef2c79
commit c6f6005990
12 changed files with 289 additions and 84 deletions
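
The headline change is that the Gemini plugin's checkpoint IO can now save and load optimizer state in shards. Below is a minimal usage sketch of the sharded path through the `Booster` API; it assumes `colossalai` has already been launched on a CUDA device, and the model, paths, and shard size are illustrative rather than taken from this commit.

```python
# Minimal sketch of sharded optimizer checkpointing with GeminiPlugin.
# Assumes colossalai.launch has already set up the process group on a
# CUDA device; the model, paths and shard size below are illustrative.
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

model = nn.Linear(1024, 1024).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)

booster = Booster(plugin=GeminiPlugin(placement_policy='cuda'))
model, optimizer, *_ = booster.boost(model, optimizer)

# ... a few training steps so the optimizer actually has state ...

# Save the optimizer state as multiple shard files under a directory,
# with each shard capped by size_per_shard; shard=True is the path this
# commit adds for the Gemini plugin.
booster.save_optimizer(optimizer, 'ckpt/optimizer', shard=True, size_per_shard=32)

# Reload the sharded checkpoint into a (re)boosted optimizer later on.
booster.load_optimizer(optimizer, 'ckpt/optimizer')
```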


@@ -52,7 +52,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
 @clear_cache_before_run()
 @parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
 @parameterize('model_name', ['transformers_gpt'])
 @parameterize('size_per_shard', [32])
 def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
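
With `shard=True` added to the parameterization, `exam_state_dict` now covers the sharded save/load round trip as well as the single-file one. The helper below is only a hedged sketch of what such a round-trip check looks like; `assert_state_close` and `check_sharded_roundtrip` are illustrative names, not code from the test file.

```python
# Hedged sketch of the sharded round trip that shard=True now exercises;
# names are illustrative, not the actual body of exam_state_dict.
import torch


def assert_state_close(a, b):
    """Recursively compare two optimizer state dicts."""
    assert type(a) == type(b)
    if isinstance(a, dict):
        assert a.keys() == b.keys()
        for k in a:
            assert_state_close(a[k], b[k])
    elif isinstance(a, (list, tuple)):
        assert len(a) == len(b)
        for x, y in zip(a, b):
            assert_state_close(x, y)
    elif isinstance(a, torch.Tensor):
        assert torch.equal(a, b)
    else:
        assert a == b


def check_sharded_roundtrip(booster, optimizer, new_optimizer, tmpdir):
    # Save into shard files capped at size_per_shard, then reload into a
    # freshly boosted optimizer and compare the two state dicts.
    booster.save_optimizer(optimizer, f'{tmpdir}/optimizer', shard=True, size_per_shard=32)
    booster.load_optimizer(new_optimizer, f'{tmpdir}/optimizer')
    assert_state_close(optimizer.state_dict(), new_optimizer.state_dict())
```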
@@ -117,7 +117,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('world_size', [2])
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)
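
The single-process case is dropped here, so the checkpoint test now runs only with two workers. For context, a condensed sketch of the spawn-based harness around this test; the body of `run_dist` is simplified and illustrative, since the real one calls the `exam_*` functions defined earlier in the test file.

```python
# Condensed, illustrative sketch of the spawn-based harness used above;
# the real run_dist calls exam_state_dict() and exam_state_dict_with_origin().
import pytest
import torch.distributed as dist

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # Every spawned worker joins the same process group through a shared free port.
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')
    assert dist.get_world_size() == world_size  # stand-in for the exam_* checks


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
    spawn(run_dist, world_size)
```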