[async io] support async io (#6137)

* support async optimizer save/load

* fix

* fix

* support pin mem

* Update low_level_zero_plugin.py

* fix

* fix

* fix

* fix

* fix
flybird11111
2024-11-18 17:52:24 +08:00
committed by Hongxin Liu
parent b90835bd32
commit eb69e640e5
15 changed files with 374 additions and 46 deletions
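For orientation, here is a minimal usage sketch of the async optimizer checkpointing that this commit adds, distilled from the test diff below. The model, optimizer, paths, and the launch/boost setup are placeholders mirroring the existing low-level-zero checkpoint test (they are not part of this diff), and the sketch assumes a CUDA-capable, torchrun-launched environment:

```python
import torch
import torch.distributed as dist
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam

# assumed setup, not part of this diff: launch, then boost a toy model with the ZeRO plugin
colossalai.launch_from_torch()
model = nn.Linear(8, 8).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
booster = Booster(plugin=LowLevelZeroPlugin(stage=2))
model, optimizer, *_ = booster.boost(model, optimizer)

# one training step so the optimizer has state worth checkpointing
loss = model(torch.randn(4, 8, device="cuda")).mean()
booster.backward(loss, optimizer)
optimizer.step()

# what this commit adds: asynchronous, safetensors-backed optimizer save
booster.save_optimizer(optimizer, "optimizer.safetensors", shard=False, use_async=True)

# drain the async pipeline (device-to-host copies, then file writes) before other ranks read the file
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()

# reload and reuse, as the updated test does with a fresh booster/optimizer
booster.load_optimizer(optimizer, "optimizer.safetensors")
```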


@@ -51,6 +51,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
             model_ckpt_path = f"{model_ckpt_path}.pt"
         if not shard and use_async:
             model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        if not shard and use_async:
+            optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors"
         booster.save_model(
             model,
             model_ckpt_path,
@@ -59,7 +61,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
         )
         # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
         booster.checkpoint_io._sync_d2h()
         booster.checkpoint_io._sync_io()
         dist.barrier()
@@ -139,7 +141,6 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
             assert torch.equal(
                 working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
             )
         new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
         check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())


@@ -0,0 +1,127 @@
import tempfile
from copy import deepcopy

import torch

from colossalai.utils.safetensors import load_flat, save_nested

try:
    from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

from colossalai.testing import check_state_dict_equal


def test_save_load():
    with tempfile.TemporaryDirectory() as tempdir:
        optimizer_state_dict = {
            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
        }
        # group_dict = {"param_groups": [0, 1, 2]}
        group_dict = {
            "param_groups": [
                {
                    "lr": 0.001,
                    "betas": (0.9, 0.999),
                    "eps": 1e-08,
                    "weight_decay": 0,
                    "bias_correction": True,
                    "params": list(range(62)),  # parameter ids 0..61
                }
            ]
        }
        metadata = deepcopy(group_dict)

        optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
        f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
        save_nested(f_writer, optimizer_state_dict, metadata)
        f_writer.sync_before_step()
        f_writer.synchronize()
        f_writer.fp.close()

        load_state_dict = load_flat(optimizer_saved_path)
        state_dict = load_state_dict["state"]
        group = {"param_groups": load_state_dict["param_groups"]}
        check_state_dict_equal(optimizer_state_dict, state_dict)
        check_state_dict_equal(group_dict, group)

        model_state_dict = {
            "module.weight0": torch.rand((1024, 1024)),
            "module.weight1": torch.rand((1024, 1024)),
            "module.weight2": torch.rand((1024, 1024)),
        }
        model_saved_path = f"{tempdir}/save_model.safetensors"
        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
        save_nested(f_writer, model_state_dict)
        f_writer.sync_before_step()
        f_writer.synchronize()
        f_writer.fp.close()

        load_state_dict = load_flat(model_saved_path)
        check_state_dict_equal(model_state_dict, load_state_dict)
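Distilled from the new test above, the core async write/sync/read cycle for a flat state dict looks roughly like this (a sketch; the file path and the `n_entries=191` value are illustrative and simply mirror the test):

```python
import torch

from colossalai.utils.safetensors import load_flat, save_nested
from tensornvme.async_file_io import AsyncFileWriter

state_dict = {"weight": torch.rand(16, 16)}

# queue the tensors for asynchronous writing to a safetensors file
writer = AsyncFileWriter(fp=open("state.safetensors", "wb"), n_entries=191, backend="pthread")
save_nested(writer, state_dict)

# flush pending device-to-host copies and file writes before closing (same call order as the test)
writer.sync_before_step()
writer.synchronize()
writer.fp.close()

# load_flat reads the file back into a plain dict of tensors
assert torch.equal(load_flat("state.safetensors")["weight"], state_dict["weight"])
```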