Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04)
[async io] support async io (#6137)

* support async optimizer save/load
* fix
* fix
* support pin mem
* Update low_level_zero_plugin.py
* fix
* fix
* fix
* fix
* fix
Committed by: Hongxin Liu
Parent: b90835bd32
Commit: eb69e640e5
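From the user's side, the change surfaces as a `use_async` flag on the booster save calls plus two sync hooks on the checkpoint IO. A minimal sketch of that flow, assuming the `Booster`/`LowLevelZeroPlugin` APIs as exercised by the updated test below; the launch, model, and optimizer setup here is placeholder boilerplate, not part of this commit:

```python
# Sketch only: async checkpointing through the Booster API, mirroring the calls
# in the test diff below. Assumes a CUDA machine and a torchrun-style launch.
import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()
booster = Booster(plugin=LowLevelZeroPlugin(stage=1))

model = torch.nn.Linear(32, 32).cuda()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

# use_async=True queues the writes instead of blocking on them
booster.save_model(model, "model.safetensors", use_async=True)
booster.save_optimizer(optimizer, "optimizer.safetensors", use_async=True)

# drain pending device-to-host copies, then pending file writes (as in the test diff)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
```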
@@ -51,6 +51,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
             model_ckpt_path = f"{model_ckpt_path}.pt"
         if not shard and use_async:
             model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        if not shard and use_async:
+            optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors"
         booster.save_model(
             model,
             model_ckpt_path,
@@ -59,7 +61,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
         )

         # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
         booster.checkpoint_io._sync_d2h()
         booster.checkpoint_io._sync_io()
         dist.barrier()
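The two `_sync_*` calls are the synchronization half of the async protocol: a reasonable reading is that `_sync_d2h` waits for tensors to finish staging from device to pinned host memory ("support pin mem" in the commit message) and `_sync_io` waits for the queued file writes to land on disk; only then is the barrier-and-reload safe. The new test further down drives the same protocol directly through tensornvme's `AsyncFileWriter`; a condensed sketch of that lifecycle (path and shapes are arbitrary, chosen for illustration):

```python
# Condensed from the new test file below: one async safetensors dump via the
# AsyncFileWriter lifecycle. Path and tensor shapes are arbitrary.
import torch
from tensornvme.async_file_io import AsyncFileWriter
from colossalai.utils.safetensors import load_flat, save_nested

state = {"weight": torch.rand(16, 16)}
f_writer = AsyncFileWriter(fp=open("/tmp/state.safetensors", "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, state)  # enqueue the serialized tensors
f_writer.sync_before_step()   # wait until the staged tensor copies are done (safe to mutate tensors)
f_writer.synchronize()        # block until the queued file writes complete
f_writer.fp.close()

# the round trip should be lossless
assert torch.equal(load_flat("/tmp/state.safetensors")["weight"], state["weight"])
```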
@@ -139,7 +141,6 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
             assert torch.equal(
                 working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
             )

         new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
         check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
-
tests/test_checkpoint_io/test_safetensors_async_io.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import tempfile
from copy import deepcopy

import torch

from colossalai.utils.safetensors import load_flat, save_nested

try:
    from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

from colossalai.testing import check_state_dict_equal


def test_save_load():
    with tempfile.TemporaryDirectory() as tempdir:
        optimizer_state_dict = {
            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
        }
        # group_dict = {"param_groups": [0, 1, 2]}
        group_dict = {
            "param_groups": [
                {
                    "lr": 0.001,
                    "betas": (0.9, 0.999),
                    "eps": 1e-08,
                    "weight_decay": 0,
                    "bias_correction": True,
                    "params": list(range(62)),  # param ids 0..61, written out one per line in the original
                }
            ]
        }
        metadata = deepcopy(group_dict)
        optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
        f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")

        save_nested(f_writer, optimizer_state_dict, metadata)
        f_writer.sync_before_step()
        f_writer.synchronize()
        f_writer.fp.close()

        load_state_dict = load_flat(optimizer_saved_path)
        state_dict = load_state_dict["state"]
        group = {"param_groups": load_state_dict["param_groups"]}
        check_state_dict_equal(optimizer_state_dict, state_dict)
        check_state_dict_equal(group_dict, group)

        model_state_dict = {
            "module.weight0": torch.rand((1024, 1024)),
            "module.weight1": torch.rand((1024, 1024)),
            "module.weight2": torch.rand((1024, 1024)),
        }
        model_saved_path = f"{tempdir}/save_model.safetensors"
        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
        save_nested(f_writer, model_state_dict)
        f_writer.sync_before_step()
        f_writer.synchronize()
        f_writer.fp.close()

        load_state_dict = load_flat(model_saved_path)
        check_state_dict_equal(model_state_dict, load_state_dict)
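`save_nested`/`load_flat` exist because safetensors itself only stores a flat string-to-tensor mapping: nested optimizer state (`{param_id: {"step": ..., "exp_avg": ...}}`) has to be flattened into keyed entries on save and rebuilt into `"state"`/`"param_groups"` on load, which is what the checks above verify. For debugging, the raw flat layout can be inspected with the stock safetensors reader; a sketch follows, noting that the exact key-naming scheme `save_nested` emits is not shown in this commit, so only the flat key-to-tensor layout should be relied on:

```python
# Sketch: inspect the raw flat key space of a file produced by save_nested,
# using the plain safetensors API. The specific key names are an
# implementation detail of save_nested, not guaranteed by this commit.
from safetensors import safe_open

with safe_open("save_optimizer.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        print(key, tuple(f.get_tensor(key).shape))
```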