mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[checkpointio] fix zero optimizer async save memory (#6151)
* [checkpointio] fix zero optimizer async save memory * [checkpointio] fit new tensornvme api * [checkpointio] fit new tensornvme api
This commit is contained in:
@@ -10,6 +10,7 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
|
||||
|
||||
|
||||
from colossalai.testing import check_state_dict_equal
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
@@ -110,20 +111,20 @@ def test_save_load():
|
||||
}
|
||||
|
||||
optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
|
||||
f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
|
||||
f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
|
||||
save_nested(f_writer, optimizer_state_dict)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
f_writer.fp.close()
|
||||
del f_writer
|
||||
load_state_dict = load_flat(optimizer_saved_path)
|
||||
check_state_dict_equal(load_state_dict, optimizer_state_dict)
|
||||
|
||||
optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
|
||||
f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
|
||||
f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
|
||||
save_nested(f_writer, optimizer_state_dict["state"])
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
f_writer.fp.close()
|
||||
del f_writer
|
||||
load_state_dict_shard = load_flat(optimizer_shard_saved_path)
|
||||
check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])
|
||||
|
||||
@@ -133,21 +134,21 @@ def test_save_load():
|
||||
"module.weight2": torch.rand((1024, 1024)),
|
||||
}
|
||||
model_saved_path = f"{tempdir}/save_model.safetensors"
|
||||
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
|
||||
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
|
||||
save(f_writer, model_state_dict)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
f_writer.fp.close()
|
||||
del f_writer
|
||||
load_state_dict = load_file(model_saved_path)
|
||||
check_state_dict_equal(model_state_dict, load_state_dict)
|
||||
|
||||
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
|
||||
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
|
||||
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
|
||||
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
|
||||
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
|
||||
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
f_writer.fp.close()
|
||||
del f_writer
|
||||
load_state_dict = load_file(model_saved_path)
|
||||
check_state_dict_equal(model_state_dict, load_state_dict)
|
||||
|
Reference in New Issue
Block a user