[checkpointio] support non blocking pin load (#6172)

* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-12-25 17:03:25 +08:00
committed by GitHub
parent 836992438f
commit af06d162cf
15 changed files with 484 additions and 174 deletions

View File

@@ -90,8 +90,16 @@ def exam_state_dict_with_origin(
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
@parameterize("use_async", [False, True])
@parameterize("low_cpu_mem_mode", [True, False])
def exam_state_dict(
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
placement_config,
shard: bool,
model_name: str,
size_per_shard: int,
tp_size: int,
zero_size: int,
use_async: bool,
low_cpu_mem_mode: bool,
):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
@@ -147,12 +155,12 @@ def exam_state_dict(
booster.checkpoint_io._sync_io()
dist.barrier()
booster.load_model(new_model, model_ckpt_path)
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
check_state_dict_equal(
model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
for group in new_optimizer.param_groups:
assert group["lr"] == 0.1