[checkpointio] support non blocking pin load (#6172)

* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Hongxin Liu
Date: 2024-12-25 17:03:25 +08:00
Committed by: GitHub
Parent: 836992438f
Commit: af06d162cf
15 changed files with 484 additions and 174 deletions
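
In brief: by default a checkpoint load materializes tensors in ordinary pageable host memory and copies them to the device synchronously. The new low_cpu_mem_mode=False path instead stages the loaded state dict in pinned (page-locked) host memory, so the host-to-device copy can be issued with non_blocking=True and overlap other work, at the cost of extra page-locked RAM. A minimal sketch of the idea, using an illustrative helper rather than ColossalAI's internal code:

import torch

def load_to_device(state_dict, device, low_cpu_mem_mode=True):
    # Illustrative only: shows the trade-off behind the flag, assuming a CUDA device.
    out = {}
    for name, tensor in state_dict.items():
        if low_cpu_mem_mode:
            out[name] = tensor.to(device)  # pageable memory, synchronous copy
        else:
            pinned = tensor.pin_memory()  # stage in page-locked host memory
            out[name] = pinned.to(device, non_blocking=True)  # async H2D copy
    torch.cuda.synchronize()  # make sure non-blocking copies have completed
    return out

The tests below exercise both settings of the flag through every checkpoint plugin.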

View File

@@ -90,8 +90,16 @@ def exam_state_dict_with_origin(
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
@parameterize("use_async", [False, True])
@parameterize("low_cpu_mem_mode", [True, False])
def exam_state_dict(
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
placement_config,
shard: bool,
model_name: str,
size_per_shard: int,
tp_size: int,
zero_size: int,
use_async: bool,
low_cpu_mem_mode: bool,
):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
@@ -147,12 +155,12 @@ def exam_state_dict(
     booster.checkpoint_io._sync_io()
     dist.barrier()

-    booster.load_model(new_model, model_ckpt_path)
+    booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(
         model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
     )
-    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+    booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))

     for group in new_optimizer.param_groups:
         assert group["lr"] == 0.1

View File

@@ -43,8 +43,11 @@ else:
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@parameterize("use_async", [False, True])
@parameterize("low_cpu_mem_mode", [False, True])
@clear_cache_before_run()
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
def exam_state_dict(
shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool, low_cpu_mem_mode: bool
):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
@@ -102,9 +105,9 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     new_optimizer = Adam(new_model.parameters(), lr=1e-3)
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
-    booster.load_model(new_model, model_ckpt_path)
+    booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict())
-    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+    booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict())
     dist.barrier()
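
For readers unfamiliar with colossalai.testing.parameterize: stacked decorators run the test once per element of the Cartesian product of their value lists, so the added low_cpu_mem_mode line doubles this file's test matrix. A rough plain-Python equivalent (argument values are placeholders, not the model registry's real keys):

from itertools import product

for use_async, low_cpu_mem_mode in product([False, True], [False, True]):
    exam_state_dict(shard=False, model_name="some_model", size_per_shard=32,
                    test_config=TEST_CONFIGS[0], use_async=use_async,
                    low_cpu_mem_mode=low_cpu_mem_mode)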

View File

@@ -29,7 +29,8 @@ from tests.kit.model_zoo import model_zoo
@parameterize("shard", [False, True])
@parameterize("offload", [False, True])
@parameterize("use_async", [False, True])
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool):
@parameterize("low_cpu_mem_mode", [False, True])
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool, low_cpu_mem_mode: bool):
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
booster = Booster(plugin=plugin)
model = resnet18()
@@ -70,7 +71,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
     new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
     new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
-    booster.load_model(new_model, model_ckpt_path)
+    booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(model.state_dict(), new_model.state_dict())
     # check master weight
     assert isinstance(new_optimizer, LowLevelZeroOptimizer)
@@ -85,7 +86,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
             working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
         )
-    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+    booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
     torch.cuda.empty_cache()
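
Why the pinned path is worth a flag: host-to-device copies from pinned memory are faster and can overlap compute, but pinned allocations cannot be swapped out, so staging a large state dict in them raises the host-RAM bill. A micro-benchmark sketch of the difference (assumes a CUDA device; timings vary by machine):

import time

import torch

t = torch.rand(1024, 1024, 64)  # roughly 256 MB of fp32

start = time.perf_counter()
t.to("cuda")
torch.cuda.synchronize()
pageable_s = time.perf_counter() - start

pinned = t.pin_memory()  # page-locked copy; this allocation cannot be swapped
start = time.perf_counter()
pinned.to("cuda", non_blocking=True)
torch.cuda.synchronize()
pinned_s = time.perf_counter() - start
print(f"pageable: {pageable_s:.3f}s  pinned+non_blocking: {pinned_s:.3f}s")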

View File

@@ -1,108 +1,144 @@
 import tempfile

+import pytest
 import torch
 from safetensors.torch import load_file

+from colossalai.checkpoint_io.utils import create_pinned_state_dict
 from colossalai.testing import check_state_dict_equal, clear_cache_before_run
 from colossalai.utils import get_current_device
 from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
+
+
+def gen_optim_state_dict():
+    return {
+        "state": {
+            0: {
+                "step": torch.tensor(1.0),
+                "exp_avg": torch.rand((1024, 1024)),
+                "exp_avg_sq": torch.rand((1024, 1024)),
+            },
+            1: {
+                "step": torch.tensor(1.0),
+                "exp_avg": torch.rand((1024, 1024)),
+                "exp_avg_sq": torch.rand((1024, 1024)),
+            },
+            2: {
+                "step": torch.tensor(1.0),
+                "exp_avg": torch.rand((1024, 1024)),
+                "exp_avg_sq": torch.rand((1024, 1024)),
+            },
+        },
+        "param_groups": [
+            {
+                "lr": 0.001,
+                "betas": (0.9, 0.999),
+                "eps": 1e-08,
+                "weight_decay": 0,
+                "bias_correction": True,
+                "params": [
+                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                    16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+                    32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+                    48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
+                ],
+            }
+        ],
+    }
+
+
+def gen_model_state_dict():
+    return {
+        "module.weight0": torch.rand((1024, 1024)),
+        "module.weight1": torch.rand((1024, 1024)),
+        "module.weight2": torch.rand((1024, 1024)),
+    }
+
+
+@pytest.mark.parametrize("empty", [True, False])
+@pytest.mark.parametrize("num_threads", [1, 4])
+def test_create_pin(empty: bool, num_threads: int):
+    model_state_dict = gen_model_state_dict()
+    model_state_dict_pinned = create_pinned_state_dict(model_state_dict, empty=empty, num_threads=num_threads)
+    for k in model_state_dict.keys():
+        assert model_state_dict_pinned[k].is_pinned()
+        if not empty:
+            assert torch.equal(model_state_dict_pinned[k], model_state_dict[k])
+    optim_state_dict = gen_optim_state_dict()
+    optim_state_dict_pinned = create_pinned_state_dict(optim_state_dict, empty=empty, num_threads=num_threads)
+    for k in optim_state_dict.keys():
+        if k == "state":
+            for idx in optim_state_dict[k].keys():
+                for kk in optim_state_dict[k][idx].keys():
+                    assert optim_state_dict_pinned[k][idx][kk].is_pinned()
+                    if not empty:
+                        assert torch.equal(optim_state_dict_pinned[k][idx][kk], optim_state_dict[k][idx][kk])
+        else:
+            assert optim_state_dict[k] == optim_state_dict_pinned[k]
+
+
 @clear_cache_before_run()
 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
-        optimizer_state_dict = {
-            "state": {
-                0: {
-                    "step": torch.tensor(1.0),
-                    "exp_avg": torch.rand((1024, 1024)),
-                    "exp_avg_sq": torch.rand((1024, 1024)),
-                },
-                1: {
-                    "step": torch.tensor(1.0),
-                    "exp_avg": torch.rand((1024, 1024)),
-                    "exp_avg_sq": torch.rand((1024, 1024)),
-                },
-                2: {
-                    "step": torch.tensor(1.0),
-                    "exp_avg": torch.rand((1024, 1024)),
-                    "exp_avg_sq": torch.rand((1024, 1024)),
-                },
-            },
-            "param_groups": [
-                {
-                    "lr": 0.001,
-                    "betas": (0.9, 0.999),
-                    "eps": 1e-08,
-                    "weight_decay": 0,
-                    "bias_correction": True,
-                    "params": [
-                        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
-                        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
-                        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
-                        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
-                    ],
-                }
-            ],
-        }
+        optimizer_state_dict = gen_optim_state_dict()
         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
         f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
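
The create_pinned_state_dict helper tested above walks a nested state dict and returns a copy whose tensors live in pinned memory, optionally left uninitialized (empty=True) and optionally built by several worker threads. A usage sketch based only on the signature exercised by test_create_pin:

import torch

from colossalai.checkpoint_io.utils import create_pinned_state_dict

sd = {"state": {0: {"exp_avg": torch.rand(8, 8)}}, "lr": 0.001}
pinned = create_pinned_state_dict(sd, empty=False, num_threads=4)
assert pinned["state"][0]["exp_avg"].is_pinned()
assert torch.equal(pinned["state"][0]["exp_avg"], sd["state"][0]["exp_avg"])
buffers = create_pinned_state_dict(sd, empty=True, num_threads=1)  # pinned, contents uninitialized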
@@ -120,11 +156,7 @@ def test_save_load():
         load_state_dict_shard = load_flat(optimizer_shard_saved_path)
         check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

-        model_state_dict = {
-            "module.weight0": torch.rand((1024, 1024)),
-            "module.weight1": torch.rand((1024, 1024)),
-            "module.weight2": torch.rand((1024, 1024)),
-        }
+        model_state_dict = gen_model_state_dict()
         model_saved_path = f"{tempdir}/save_model.safetensors"
         f_writer = save(model_saved_path, model_state_dict)
         f_writer.sync_before_step()
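
The writer returned by save and save_nested completes its I/O asynchronously, which is why the test syncs before reading the file back. The pattern, condensed (same names as in the test above):

f_writer = save(model_saved_path, model_state_dict)  # queues an async write
f_writer.sync_before_step()  # block until the queued write is flushed
loaded = load_file(model_saved_path)  # safetensors.torch.load_file
check_state_dict_equal(loaded, model_state_dict)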

View File

@@ -15,7 +15,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_ad
@parameterize("shard", [False, True])
@parameterize("size_per_shard", [16, 128])
@parameterize("use_async", [False, True])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool):
@parameterize("low_cpu_mem_mode", [False, True])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool, low_cpu_mem_mode: bool):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
@@ -61,10 +62,10 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bo
         new_model, new_optimizer, lr_scheduler=new_scheduler
     )
-    booster.load_model(new_model, model_ckpt_path)
+    booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(model.state_dict(), new_model.state_dict())
-    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+    booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
     check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
     booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
     check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict())
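
Taken together, the shape of the round trip after this change (a sketch; variable names mirror the test parameters above, and the save kwargs are inferred from the test matrix rather than quoted from the diff):

booster.save_model(model, model_ckpt_path, shard=shard, use_async=use_async)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
booster.checkpoint_io._sync_io()  # wait for async writes, as the Gemini test does
dist.barrier()
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)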