Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 20:10:17 +00:00
[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)
* sharded optimizer checkpoint for gemini plugin
* modify test to reduce testing time
* update doc
* fix bug when keep_gathered is True under GeminiPlugin
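For context, a minimal sketch of how the sharded optimizer checkpoint path added by this commit can be exercised through the Booster API. The toy model, optimizer choice, and checkpoint path are placeholders, and the keyword arguments reflect the public Booster checkpoint I/O interface around the time of this commit, so treat exact names and defaults as assumptions rather than a verbatim excerpt from the PR:

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumes the script is launched with torchrun so the distributed
# environment variables (rank, world size, etc.) are already set.
colossalai.launch_from_torch(config={})

model = nn.Linear(32, 32).cuda()  # placeholder model
optimizer = HybridAdam(model.parameters(), lr=1e-3)

booster = Booster(plugin=GeminiPlugin(placement_policy='cuda'))
model, optimizer, *_ = booster.boost(model, optimizer)

# One training step so the optimizer has state worth checkpointing.
loss = model(torch.rand(4, 32, device='cuda')).sum()
booster.backward(loss, optimizer)
optimizer.step()

# shard=True writes the optimizer state as a directory of shard files
# instead of a single file; size_per_shard caps each shard's size (MB).
booster.save_optimizer(optimizer, './optim_ckpt', shard=True, size_per_shard=32)
booster.load_optimizer(optimizer, './optim_ckpt')

The sharded format matters for Gemini in particular because optimizer states are distributed across ranks, so avoiding a single monolithic state file keeps peak memory during save/load bounded.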
@@ -52,7 +52,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
 
 @clear_cache_before_run()
 @parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
 @parameterize('model_name', ['transformers_gpt'])
 @parameterize('size_per_shard', [32])
 def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
@@ -117,7 +117,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('world_size', [2])
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)
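The test scaffolding visible in the diff follows ColossalAI's usual spawn-based pattern: `@parameterize` fans a worker function out over every argument combination inside each process, while `spawn` launches that worker across ranks. A hedged sketch of the structure (the body of `exam_state_dict` is illustrative, not the actual test code):

import pytest

import colossalai
from colossalai.testing import (clear_cache_before_run, parameterize,
                                rerun_if_address_is_in_use, spawn)

@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('shard', [False, True])  # after this commit the sharded path is covered too
def exam_state_dict(placement_policy, shard: bool):
    # Illustrative body: build a Gemini-boosted model/optimizer, run a
    # step, save model and optimizer checkpoints (sharded or whole),
    # reload them, and assert the state dicts match the originals.
    ...

def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')
    exam_state_dict()

@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])  # narrowed from [1, 2] to cut test time
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
    spawn(run_dist, world_size)

Because `@parameterize` expands arguments inside the spawned worker rather than at the pytest level, dropping `world_size=1` halves the number of process launches while the per-process parameter sweep still covers every placement/shard combination.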