Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-22 18:09:06 +00:00)
[checkpointio] support load-pin overlap (#6177)
* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
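In short, this commit adds two helpers to the checkpoint I/O utilities. load_optim_or_model_shard reads a single shard from disk, dispatching to load_flat for flattened safetensors optimizer states and to load_shard_state_dict otherwise. load_state_dict_shards wraps it in a generator: with low_cpu_mem_mode=True it reads and yields shards one at a time, while with low_cpu_mem_mode=False it submits reads to a ThreadPoolExecutor with `prefetch` workers, so disk I/O for upcoming shards overlaps with whatever the caller does with each yielded shard (typically pinning it, hence "load-pin overlap").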
@@ -6,7 +6,7 @@ from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -21,7 +21,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import _flatten_optim_state_dict
+from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -972,3 +972,35 @@ def create_pinned_state_dict(
             idx = future_to_idx[future]
             elems[idx] = future.result()
     return tree_unflatten(elems, spec)
+
+
+def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict:
+    if is_optim:
+        if path.endswith(".safetensors"):
+            state_dict = load_flat(path)
+        else:
+            state_dict = load_shard_state_dict(Path(path), use_safetensors=False)
+    else:
+        state_dict = load_shard_state_dict(Path(path), use_safetensors)
+    return state_dict
+
+
+def load_state_dict_shards(
+    checkpoint_files: List[str],
+    is_optim: bool,
+    use_safetensors: bool,
+    low_cpu_mem_mode: bool = True,
+    prefetch: int = 3,
+) -> Generator[dict, None, None]:
+    if low_cpu_mem_mode:
+        for shard_file in checkpoint_files:
+            state_dict = load_optim_or_model_shard(shard_file, is_optim, use_safetensors)
+            yield state_dict
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
+            futures = []
+            for shard_file in checkpoint_files:
+                future = executor.submit(load_optim_or_model_shard, shard_file, is_optim, use_safetensors)
+                futures.append(future)
+            for future in concurrent.futures.as_completed(futures):
+                yield future.result()
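A minimal usage sketch (not part of the commit), assuming the helper is importable from colossalai.checkpoint_io.utils, the module the diff appears to modify; the shard file names are hypothetical placeholders, and pinning is done here with plain torch.Tensor.pin_memory rather than the module's own create_pinned_state_dict, whose exact signature is not shown in this diff:

import torch

# Assumed import path for the new helper added by this commit.
from colossalai.checkpoint_io.utils import load_state_dict_shards

# Hypothetical shard list; in practice the paths come from the checkpoint's index file.
shard_files = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]

# With low_cpu_mem_mode=False, up to `prefetch` worker threads read upcoming
# shards from disk while this loop processes the shard just yielded, so disk
# I/O overlaps with pinning instead of running back to back.
for state_dict in load_state_dict_shards(shard_files, is_optim=False, use_safetensors=True, low_cpu_mem_mode=False):
    # Pin each tensor so later host-to-device copies can run asynchronously.
    pinned = {k: v.pin_memory() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()}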