Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 21:09:18 +00:00
[checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)
* [checkpointio] fix hybrid parallel optim checkpoint
* [extension] fix cuda extension
* [checkpointio] fix gemini optimizer checkpoint
* polish code
@@ -97,7 +97,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     new_model = model_fn()
     optimizer = HybridAdam(model.parameters(), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+    new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

     data = data_gen_fn()
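The resumed optimizer is now built with a different learning rate (0.01 instead of 0.001) from the one that gets checkpointed, so the test can only pass if loading the checkpoint actually overwrites the fresh optimizer's hyperparameters. A minimal sketch of that property, with torch.optim.Adam standing in for HybridAdam so it runs without the CUDA extension (a stand-in, not part of the diff):

    # Sketch: a checkpoint load must overwrite the resumed optimizer's param_groups.
    import torch

    model = torch.nn.Linear(4, 4)
    saved = torch.optim.Adam(model.parameters(), lr=0.001)   # optimizer that gets checkpointed
    resumed = torch.optim.Adam(model.parameters(), lr=0.01)  # deliberately different lr

    # If param_groups were not restored, the two lr values would still disagree.
    resumed.load_state_dict(saved.state_dict())
    assert resumed.param_groups[0]["lr"] == 0.001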
@@ -109,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha

     booster.backward(loss, optimizer)
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1

     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
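This hunk mutates the learning rate after optimizer.step() but before saving, so the checkpoint captures a param-group value (0.1) that differs from both construction-time values. A short sketch of why the mutation lands in the saved state (again using torch.optim.Adam as a stand-in):

    # Sketch: an lr mutated in param_groups is reflected in the saved state dict.
    import torch

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    model(torch.randn(2, 4)).sum().backward()
    opt.step()

    for group in opt.param_groups:
        group["lr"] = 0.1  # mutate after stepping, before saving

    assert opt.state_dict()["param_groups"][0]["lr"] == 0.1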
@@ -127,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
         check_state_dict_equal(
             optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
         )
+        for group in new_optimizer.param_groups:
+            assert group["lr"] == 0.1

     # Check the new model/optimizer can successfully run.
     data = data_gen_fn()
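The added assertion closes the loop: after booster.load_optimizer, the restored param groups must carry the lr of 0.1 set before saving, not the 0.01 the new optimizer was constructed with, so hyperparameters are covered in addition to the per-parameter state compared by check_state_dict_equal. An end-to-end sketch of the round trip, with torch.save/torch.load and torch.optim.Adam standing in for booster.save_optimizer/load_optimizer and HybridAdam (stand-ins chosen for portability, not the test's actual API):

    # Sketch: full save/load round trip of the property the assertion checks.
    import tempfile
    import torch

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    for group in opt.param_groups:
        group["lr"] = 0.1  # value the checkpoint should carry

    new_opt = torch.optim.Adam(model.parameters(), lr=0.01)
    with tempfile.TemporaryDirectory() as tempdir:
        path = f"{tempdir}/optimizer.pt"
        torch.save(opt.state_dict(), path)
        new_opt.load_state_dict(torch.load(path))

    for group in new_opt.param_groups:
        assert group["lr"] == 0.1  # construction-time 0.01 was overwritten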