mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[checkpointio] support debug log (#6153)
* [checkpointio] support debug log * [checkpointio] refactor async writer api * fix test * fix test
This commit is contained in:
@@ -3,18 +3,12 @@ import tempfile
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from colossalai.testing import check_state_dict_equal, clear_cache_before_run
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
|
||||
|
||||
try:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
|
||||
|
||||
|
||||
from colossalai.testing import check_state_dict_equal
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def test_save_load():
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
optimizer_state_dict = {
|
||||
@@ -111,8 +105,7 @@ def test_save_load():
|
||||
}
|
||||
|
||||
optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
|
||||
f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
|
||||
save_nested(f_writer, optimizer_state_dict)
|
||||
f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
del f_writer
|
||||
@@ -120,8 +113,7 @@ def test_save_load():
|
||||
check_state_dict_equal(load_state_dict, optimizer_state_dict)
|
||||
|
||||
optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
|
||||
f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
|
||||
save_nested(f_writer, optimizer_state_dict["state"])
|
||||
f_writer = save_nested(optimizer_shard_saved_path, optimizer_state_dict["state"])
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
del f_writer
|
||||
@@ -134,8 +126,7 @@ def test_save_load():
|
||||
"module.weight2": torch.rand((1024, 1024)),
|
||||
}
|
||||
model_saved_path = f"{tempdir}/save_model.safetensors"
|
||||
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
|
||||
save(f_writer, model_state_dict)
|
||||
f_writer = save(model_saved_path, model_state_dict)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
del f_writer
|
||||
@@ -145,8 +136,7 @@ def test_save_load():
|
||||
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
|
||||
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
|
||||
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
|
||||
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
|
||||
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
|
||||
f_writer = move_and_save(model_saved_path, model_state_dict_cuda, model_state_pinned)
|
||||
f_writer.sync_before_step()
|
||||
f_writer.synchronize()
|
||||
del f_writer
|
||||
|
Reference in New Issue
Block a user