[checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)

* [checkpointio] fix hybrid parallel optim checkpoint

* [extension] fix cuda extension

* [checkpointio] fix gemini optimizer checkpoint

* polish code
Author: Hongxin Liu
Date: 2024-02-01 16:13:06 +08:00
Committed by: GitHub
Parent: c5239840e6
Commit: ffffc32dc7
5 changed files with 35 additions and 8 deletions
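The hunks below are from the Gemini plugin checkpoint test: after an optimizer save/load round trip, the test now also checks that per-group hyperparameters (the learning rate) are restored, not just the optimizer state tensors. For orientation, here is a minimal sketch of the flow the test exercises, assuming ColossalAI's Booster checkpoint API (`booster.save_optimizer` / `booster.load_optimizer`); `model_fn`, `criterion`, the checkpoint path, and the elided training step are placeholders, and a distributed launch is assumed to have happened already, as it does in the test harness:

```python
# Minimal sketch of the round trip the test exercises, assuming ColossalAI's
# Booster checkpoint API; colossalai.launch is assumed to have initialized the
# process group beforehand, as the test harness does.
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

def model_fn():           # stand-in for the test's registered model factory
    return nn.Linear(8, 8)

criterion = nn.MSELoss()  # stand-in loss
booster = Booster(plugin=GeminiPlugin())

model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=0.001)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# ... one forward/backward pass and optimizer.step() happen here ...

for group in optimizer.param_groups:  # mutate a hyperparameter after stepping
    group["lr"] = 0.1
booster.save_optimizer(optimizer, "optimizer_ckpt", shard=True)

# A fresh model/optimizer pair, constructed with a deliberately different lr.
new_model = model_fn()
new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

booster.load_optimizer(new_optimizer, "optimizer_ckpt")
# Loading must restore both the state tensors and the param-group hyperparameters.
assert all(group["lr"] == 0.1 for group in new_optimizer.param_groups)
```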

@@ -97,7 +97,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     new_model = model_fn()
     optimizer = HybridAdam(model.parameters(), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+    new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
     data = data_gen_fn()
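The only functional change in this hunk is the new optimizer's learning rate: it is now constructed with an lr (0.01) that differs from the original optimizer's, so a hyperparameter that fails to round-trip through the checkpoint shows up as a mismatch rather than coincidentally matching. The mechanism being tested is ordinary PyTorch behavior: per-group hyperparameters such as lr are part of the optimizer's state_dict() under "param_groups" and are overwritten by load_state_dict(), as this small standalone example (not from the test) shows:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(2))]
opt = torch.optim.SGD(params, lr=0.001)
opt.param_groups[0]["lr"] = 0.1             # mutate the hyperparameter in place

saved = opt.state_dict()                    # {"state": ..., "param_groups": [...]}
assert saved["param_groups"][0]["lr"] == 0.1

new_opt = torch.optim.SGD(params, lr=0.01)  # different constructor lr
new_opt.load_state_dict(saved)              # loading must overwrite it
assert new_opt.param_groups[0]["lr"] == 0.1
```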
@@ -109,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     booster.backward(loss, optimizer)
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
@@ -127,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
         check_state_dict_equal(
             optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
         )
+        for group in new_optimizer.param_groups:
+            assert group["lr"] == 0.1
     # Check the new model/optimizer can successfully run.
     data = data_gen_fn()
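With this commit, the load is verified on two levels: check_state_dict_equal compares the optimizer state, and the new loop asserts that the learning rate written before saving (0.1) is what comes back after booster.load_optimizer. A small hypothetical helper (not part of the test) generalizes that second check to other per-group hyperparameters:

```python
def assert_param_groups_match(opt_a, opt_b, keys=("lr", "weight_decay", "betas")):
    """Hypothetical helper: assert per-group hyperparameters agree between two optimizers."""
    groups_a = opt_a.state_dict()["param_groups"]
    groups_b = opt_b.state_dict()["param_groups"]
    assert len(groups_a) == len(groups_b), "optimizers have different numbers of param groups"
    for ga, gb in zip(groups_a, groups_b):
        for key in keys:
            if key in ga or key in gb:
                assert ga.get(key) == gb.get(key), f"{key} mismatch: {ga.get(key)} != {gb.get(key)}"
```

Calling assert_param_groups_match(optimizer, new_optimizer) after loading would catch the class of regression this commit fixes, where the state tensors round-trip correctly but per-group settings such as lr do not.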