[checkpointio] support load-pin overlap (#6177)

* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Hongxin Liu
Date: 2025-01-07 16:16:04 +08:00
Committed by: GitHub
Parent: 479067e9bc
Commit: ee81366cac
6 changed files with 56 additions and 32 deletions


@@ -6,7 +6,7 @@ from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -21,7 +21,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import _flatten_optim_state_dict
+from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -972,3 +972,35 @@ def create_pinned_state_dict(
             idx = future_to_idx[future]
             elems[idx] = future.result()
     return tree_unflatten(elems, spec)
+
+
+def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict:
+    if is_optim:
+        if path.endswith(".safetensors"):
+            state_dict = load_flat(path)
+        else:
+            state_dict = load_shard_state_dict(Path(path), use_safetensors=False)
+    else:
+        state_dict = load_shard_state_dict(Path(path), use_safetensors)
+    return state_dict
+
+
+def load_state_dict_shards(
+    checkpoint_files: List[str],
+    is_optim: bool,
+    use_safetensors: bool,
+    low_cpu_mem_mode: bool = True,
+    prefetch: int = 3,
+) -> Generator[dict, None, None]:
+    if low_cpu_mem_mode:
+        for shard_file in checkpoint_files:
+            state_dict = load_optim_or_model_shard(shard_file, is_optim, use_safetensors)
+            yield state_dict
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
+            futures = []
+            for shard_file in checkpoint_files:
+                future = executor.submit(load_optim_or_model_shard, shard_file, is_optim, use_safetensors)
+                futures.append(future)
+            for future in concurrent.futures.as_completed(futures):
+                yield future.result()
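
With low_cpu_mem_mode=False, load_state_dict_shards reads up to `prefetch` shard files on background threads while the caller consumes each yielded shard, so disk I/O overlaps with the per-shard work (typically pinning host memory). Below is a minimal usage sketch, not part of the commit: the shard file names are hypothetical, the import path assumes these helpers live in colossalai.checkpoint_io.utils alongside the diff above, and it assumes create_pinned_state_dict (shown in the diff context) takes a state dict as its first argument.

    from colossalai.checkpoint_io.utils import (
        create_pinned_state_dict,
        load_state_dict_shards,
    )

    # Hypothetical shard files; real checkpoints list these in an index file.
    checkpoint_files = [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ]

    # With low_cpu_mem_mode=False, up to `prefetch` shards are read on
    # background threads while the main thread pins the current shard,
    # overlapping load with pin.
    for state_dict in load_state_dict_shards(
        checkpoint_files, is_optim=False, use_safetensors=True, low_cpu_mem_mode=False
    ):
        pinned = create_pinned_state_dict(state_dict)
        # ... copy the pinned tensors into the model / device buffers ...

Note that the prefetch path yields via concurrent.futures.as_completed, so shards can arrive out of order; that is harmless here because each shard holds a disjoint set of keys.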