Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[optim] hotfix adam load (#6146)
* [optim] hotfix adam load
* [checkpointio] fix optimizer async io
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [checkpointio] update test
* [checkpointio] update test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
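Read from the diff below: the old test saved the flat per-parameter Adam state and passed the param groups separately as a metadata argument; the fixed test keeps "state" and "param_groups" in one nested dict, saves it with save_nested (no metadata argument), and load_flat returns that whole nested dict. A minimal sketch of the fixed round trip, assuming tensornvme is installed (the writer arguments and call sequence are copied from the test; tensor shapes are shrunk for brevity):

import torch
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import load_flat, save_nested

# Nested Adam-style optimizer state: per-parameter tensors plus param_groups.
optimizer_state_dict = {
    "state": {0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((4, 4)), "exp_avg_sq": torch.rand((4, 4))}},
    "param_groups": [{"lr": 0.001, "params": [0]}],
}

path = "save_optimizer.safetensors"
f_writer = AsyncFileWriter(fp=open(path, "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, optimizer_state_dict)  # no separate metadata argument any more
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()

loaded = load_flat(path)  # nested dict back, "state" and "param_groups" included
assert loaded["param_groups"][0]["lr"] == 0.001

The same save_nested/load_flat pair also round-trips a bare shard such as optimizer_state_dict["state"], which is what the new shard check in the test exercises.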
@@ -1,9 +1,9 @@
 import tempfile
 from copy import deepcopy
 
 import torch
 from safetensors.torch import load_file
 
-from colossalai.utils.safetensors import load_flat, save_nested
+from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
 
 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -11,17 +11,29 @@ except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
 
 from colossalai.testing import check_state_dict_equal
+from colossalai.utils import get_current_device
 
 
 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {
-            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-        }
-        # group_dict = {"param_groups": [0, 1, 2]}
-        group_dict = {
+            "state": {
+                0: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                1: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                2: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+            },
             "param_groups": [
                 {
                     "lr": 0.001,
@@ -94,22 +106,26 @@ def test_save_load():
                         61,
                     ],
                 }
-            ]
+            ],
         }
-        metadata = deepcopy(group_dict)
 
         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
         f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict, metadata)
+        save_nested(f_writer, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()
 
         load_state_dict = load_flat(optimizer_saved_path)
-        state_dict = load_state_dict["state"]
-        group = {"param_groups": load_state_dict["param_groups"]}
-        check_state_dict_equal(optimizer_state_dict, state_dict)
-        check_state_dict_equal(group_dict, group)
+        check_state_dict_equal(load_state_dict, optimizer_state_dict)
+
+        optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict_shard = load_flat(optimizer_shard_saved_path)
+        check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])
 
         model_state_dict = {
             "module.weight0": torch.rand((1024, 1024)),
@@ -118,10 +134,20 @@ def test_save_load():
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
         f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, model_state_dict)
+        save(f_writer, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()
 
-        load_state_dict = load_flat(model_saved_path)
+        load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)
+
+        model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
+        model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
+        model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
+        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict = load_file(model_saved_path)
+        check_state_dict_equal(model_state_dict, load_state_dict)
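The tail of the diff pins down the two model-weight paths: plain CPU tensors go through save and come back with stock safetensors.torch.load_file, while CUDA tensors are first copied into pre-pinned host buffers by move_and_save. A condensed sketch of both paths, assuming tensornvme and a CUDA device are available; write_async is a hypothetical helper that just wraps the writer lifecycle used throughout the test:

import torch
from safetensors.torch import load_file
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils import get_current_device
from colossalai.utils.safetensors import move_and_save, save


def write_async(path, write_fn):
    # Hypothetical helper: same write/flush/close sequence as the test.
    f_writer = AsyncFileWriter(fp=open(path, "wb"), n_entries=191, backend="pthread")
    write_fn(f_writer)
    f_writer.sync_before_step()
    f_writer.synchronize()
    f_writer.fp.close()


model_state_dict = {"module.weight0": torch.rand((1024, 1024))}

# Path 1: CPU tensors, written asynchronously, readable by vanilla safetensors.
write_async("save_model.safetensors", lambda w: save(w, model_state_dict))
assert load_file("save_model.safetensors").keys() == model_state_dict.keys()

# Path 2: CUDA tensors, staged through pinned CPU memory before the async write.
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
write_async("save_model_cuda.safetensors", lambda w: move_and_save(w, model_state_dict_cuda, model_state_pinned))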