mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 05:49:55 +00:00
[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -90,8 +90,16 @@ def exam_state_dict_with_origin(
|
||||
@parameterize("tp_size", [1, 2])
|
||||
@parameterize("zero_size", [2])
|
||||
@parameterize("use_async", [False, True])
|
||||
@parameterize("low_cpu_mem_mode", [True, False])
|
||||
def exam_state_dict(
|
||||
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
|
||||
placement_config,
|
||||
shard: bool,
|
||||
model_name: str,
|
||||
size_per_shard: int,
|
||||
tp_size: int,
|
||||
zero_size: int,
|
||||
use_async: bool,
|
||||
low_cpu_mem_mode: bool,
|
||||
):
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
criterion = lambda x: x.mean()
|
||||
@@ -147,12 +155,12 @@ def exam_state_dict(
|
||||
booster.checkpoint_io._sync_io()
|
||||
dist.barrier()
|
||||
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
booster.load_model(new_model, model_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
|
||||
check_state_dict_equal(
|
||||
model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
|
||||
)
|
||||
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path, low_cpu_mem_mode=low_cpu_mem_mode)
|
||||
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
|
||||
for group in new_optimizer.param_groups:
|
||||
assert group["lr"] == 0.1
|
||||
|
Reference in New Issue
Block a user