[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
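This PR threads two new keyword arguments, `low_cpu_mem_mode` and `num_threads`, through every model and optimizer load path. With `low_cpu_mem_mode=False`, the freshly loaded state dict is copied into page-locked (pinned) host memory, optionally using several threads, so the subsequent host-to-device transfers can run with `non_blocking=True` behind a single synchronize. A minimal standalone sketch of that pattern (plain PyTorch, not ColossalAI code):

```python
import torch

def load_to_device(state_dict: dict, device: torch.device) -> dict:
    """Pin-then-async-copy pattern that this PR applies to checkpoint loading."""
    out = {}
    for name, cpu_tensor in state_dict.items():
        pinned = cpu_tensor.pin_memory()  # stage the tensor in page-locked RAM
        out[name] = pinned.to(device, non_blocking=True)  # async H2D copy
    torch.cuda.synchronize()  # one fence after all copies are queued
    return out

if __name__ == "__main__":
    sd = {"weight": torch.randn(1024, 1024), "bias": torch.randn(1024)}
    if torch.cuda.is_available():
        gpu_sd = load_to_device(sd, torch.device("cuda"))
        print({k: v.device for k, v in gpu_sd.items()})
```

The annotated diff follows.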
@@ -85,7 +85,12 @@ class CheckpointIO(ABC):
         self._sync_io()

     def load_model(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ) -> Union[nn.Module, ModelWrapper]:
         """
         Load model from checkpoint.
@@ -100,6 +105,8 @@ class CheckpointIO(ABC):
                 Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """
         # since we only support loaded sharded and unsharded weight format
         # containing no distributed tensors, dtensor -> full tensor conversion
@@ -111,17 +118,25 @@ class CheckpointIO(ABC):
         origin_model = model

         if index_file_exists:
-            self.load_sharded_model(model, index_file_path, strict)
+            self.load_sharded_model(
+                model, index_file_path, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+            )
         else:
             path = Path(checkpoint, SAFE_WEIGHTS_NAME)
             if path.is_file():
-                self.load_unsharded_model(model, str(path), strict)
+                self.load_unsharded_model(
+                    model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+                )
             else:
                 path = Path(checkpoint, WEIGHTS_NAME)
                 if path.is_file():
-                    self.load_unsharded_model(model, str(path), strict)
+                    self.load_unsharded_model(
+                        model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+                    )
                 else:
-                    self.load_unsharded_model(model, checkpoint, strict)
+                    self.load_unsharded_model(
+                        model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+                    )

         return origin_model

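Callers opt in through the public `load_model` API shown above (and `load_optimizer` below gets the same treatment); existing code is unaffected because both knobs default to the old behavior. A hypothetical call site, with the model and checkpoint path as placeholders:

```python
import torch.nn as nn
from colossalai.checkpoint_io import GeneralCheckpointIO

ckpt_io = GeneralCheckpointIO()
model = nn.Linear(16, 16)  # placeholder model

# Default: low CPU memory mode, identical to the behavior before this PR.
model = ckpt_io.load_model(model, "ckpt_dir/")  # hypothetical path

# Opt in: cache weights in pinned RAM, pinning with 4 worker threads.
model = ckpt_io.load_model(model, "ckpt_dir/", low_cpu_mem_mode=False, num_threads=4)
```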
@@ -178,7 +193,14 @@ class CheckpointIO(ABC):
         else:
             self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)

-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
+    def load_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        prefix: str = None,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load optimizer from checkpoint.

@@ -187,7 +209,8 @@ class CheckpointIO(ABC):
             checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
             prefix (str, optional): A prefix added to parameter and buffer
                 names to compose the keys in state_dict. Defaults to None.
-            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """

         index_file_exists, index_file_path = has_index_file(checkpoint)
@@ -198,9 +221,13 @@ class CheckpointIO(ABC):

         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
-            self.load_sharded_optimizer(optimizer, index_file_path, prefix)
+            self.load_sharded_optimizer(
+                optimizer, index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+            )
         else:
-            self.load_unsharded_optimizer(optimizer, checkpoint)
+            self.load_unsharded_optimizer(
+                optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+            )

     def save_optimizer(
         self,
@@ -238,7 +265,9 @@ class CheckpointIO(ABC):
     # Abstract methods for model loading/saving implementation
     # ========================================================
     @abstractmethod
-    def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
+    def load_sharded_model(
+        self, model: nn.Module, index_file_path: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load model from sharded checkpoint.

@@ -247,10 +276,14 @@ class CheckpointIO(ABC):
             index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """

     @abstractmethod
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+    def load_unsharded_model(
+        self, model: nn.Module, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load model from unsharded checkpoint.

@@ -259,6 +292,8 @@ class CheckpointIO(ABC):
             checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """

     @abstractmethod
@@ -303,7 +338,14 @@ class CheckpointIO(ABC):
     # ========================================================

     @abstractmethod
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load optimizer from sharded checkpoint.

@@ -311,16 +353,22 @@ class CheckpointIO(ABC):
             optimizer (Optimizer): optimizer to be loaded.
             index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
             prefix (str): prefix for the optimizer checkpoint.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """

     @abstractmethod
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+    def load_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load optimizer from unsharded checkpoint.

         Args:
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """

     @abstractmethod
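Because these are abstract signatures, out-of-tree `CheckpointIO` subclasses must grow the same two parameters to stay compatible. A sketch of the new override shape (the class name and body are hypothetical, and the remaining abstract methods are omitted for brevity):

```python
import torch
import torch.nn as nn
from colossalai.checkpoint_io import CheckpointIO

class MyCheckpointIO(CheckpointIO):
    def load_unsharded_model(
        self, model: nn.Module, checkpoint: str, strict: bool,
        low_cpu_mem_mode: bool = True, num_threads: int = 1,
    ):
        # A real implementation would honor low_cpu_mem_mode/num_threads here.
        state_dict = torch.load(checkpoint, map_location="cpu")
        model.load_state_dict(state_dict, strict=strict)
```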
@@ -1,4 +1,3 @@
-import gc
 import logging
 import os
 from functools import reduce
@@ -40,8 +39,17 @@ class GeneralCheckpointIO(CheckpointIO):
     Checkpoint IO
     """

-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+    def load_unsharded_model(
+        self,
+        model: nn.Module,
+        checkpoint: str,
+        strict: bool,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
         model.load_state_dict(checkpoint, strict=strict)

     def save_unsharded_model(
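Every load-path call site passes `empty=False`, which matters: the helper either allocates uninitialized pinned buffers (useful for staging async saves) or actually copies the loaded tensors into pinned memory, and only the latter works as a load-time RAM cache. The two behaviors in isolation, in plain PyTorch:

```python
import torch

t = torch.randn(4, 4)

# empty=True: allocate a pinned buffer with t's shape/dtype, contents undefined.
buf = torch.empty_like(t, pin_memory=True, device="cpu")

# empty=False: copy t into pinned memory, preserving the data.
pinned = t.pin_memory()
assert pinned.is_pinned() and torch.equal(pinned, t)
```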
@@ -60,7 +68,14 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)

-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load sharded optimizer with the given path to index file.
         """
@@ -84,6 +99,8 @@ class GeneralCheckpointIO(CheckpointIO):
                 state_dict = load_flat(shard_file)
             else:
                 state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if not low_cpu_mem_mode:
+                state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_states_into_optimizer(optimizer, state_dict, id_map)

         sharded_optimizer_loading_epilogue(optimizer)
@@ -158,11 +175,15 @@ class GeneralCheckpointIO(CheckpointIO):
             f"index located at {save_index_file}."
         )

-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+    def load_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         if checkpoint.endswith(".safetensors"):
             checkpoint = load_flat(checkpoint)
         else:
             checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
         optimizer.load_state_dict(checkpoint)

     def save_unsharded_optimizer(
@@ -256,6 +277,8 @@ class GeneralCheckpointIO(CheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         load shard model, load model from multiple files
@@ -274,9 +297,9 @@ class GeneralCheckpointIO(CheckpointIO):

         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+            if not low_cpu_mem_mode:
+                state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
-            del state_dict
-            gc.collect()

         if strict:
             remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
@@ -355,7 +355,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             f"index located at {final_index_file_path}."
         )

-    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
+    def load_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint_index_file: Path,
+        strict: bool = False,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load sharded model with the given path to index file of checkpoint folder.

@@ -403,6 +410,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

            file_path = os.path.join(ckpt_root_path, filename)
            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
+           if not low_cpu_mem_mode:
+               state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)

            load_state_dict_into_model(
                model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
@@ -632,7 +641,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             f"index located at {final_index_file_path}."
         )

-    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+    def load_sharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint_index_file: str,
+        prefix: str = "",
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load sharded optimizer with the given path to index file of checkpoint folder.

@@ -706,6 +722,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                     state_dict = load_flat(file_path)
                 else:
                     state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+                if not low_cpu_mem_mode:
+                    state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
                 load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
                 loaded_file.add(filename)

@@ -789,7 +807,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         else:
             save_state_dict(complete_state_dict, checkpoint, use_safetensors)

-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
+    def load_unsharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        strict: bool = False,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load model from a single file with the given path of checkpoint.

@@ -812,6 +837,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
         # model.load_state_dict can be directly called.
         state_dict = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
         model.load_state_dict(state_dict, strict=strict)

         # Update master params if mixed-precision training is enabled.
@@ -912,7 +939,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)

-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load optimizer from a file with given path.

@@ -940,6 +969,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             state_dict = load_flat(checkpoint)
         else:
             state_dict = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)

         # Load param_groups.
         updated_groups = []
@@ -510,7 +510,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
             f"index located at {final_index_file_path}."
         )

-    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+    def load_sharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint_index_file: str,
+        prefix: str = "",
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load sharded optimizer with the given path to index file of checkpoint folder.

@@ -795,7 +802,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         dist.barrier()

     # Copied from colossalai.moe
-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
+    def load_unsharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint: str,
+        strict: bool = False,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load optimizer from a file with given path.

@@ -1,18 +1,20 @@
 # coding=utf-8
+import concurrent.futures
 import os
 import re
 from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple
+from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union

 import torch
 import torch.nn as nn
 from packaging.version import Version
 from torch.optim import Optimizer
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

+from colossalai.accelerator import get_accelerator
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
@@ -791,7 +793,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
                 if key != "step":
                     if param.is_floating_point():
                         value = value.to(param.dtype)
-                    value = value.to(param.device)
+                    value = value.to(param.device, non_blocking=True)
                 return value
             elif isinstance(value, dict):
                 return {k: cast(param, v, key=k) for k, v in value.items()}
@@ -811,6 +813,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
         elif not strict:
             new_states[k] = v

+    get_accelerator().synchronize()
     optimizer.state.update(new_states)

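These two hunks are the payoff and the safety net: every optimizer state tensor is now moved with an asynchronous copy, and a single `get_accelerator().synchronize()` fence before `optimizer.state.update(...)` guarantees all queued copies have completed before the source buffers can be released or the states consumed. Illustrated in plain PyTorch, assuming a CUDA device:

```python
import torch

if torch.cuda.is_available():
    src = torch.randn(8, 1024, 1024).pin_memory()  # pinned source => truly async
    dst = src.to("cuda", non_blocking=True)  # returns immediately, copy in flight
    # Without a fence, freeing or reusing src here could race with the copy.
    torch.cuda.synchronize()
    assert torch.equal(dst.cpu(), src)
```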
@@ -945,8 +948,27 @@ def get_shard_filename(weights_name: str, idx: int):
     return shard_file


-def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
-    pin_mem = dict()
-    for name, tensor in state_dict.items():
-        pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
-    return pin_mem
+def _pin_tensor(tensor: torch.Tensor, empty: bool = True) -> torch.Tensor:
+    if empty:
+        return torch.empty_like(tensor, pin_memory=True, device="cpu")
+    return tensor.pin_memory()
+
+
+def create_pinned_state_dict(
+    state_dict: Union[Dict[str, torch.Tensor], Dict[int, Dict[str, torch.Tensor]]],
+    empty: bool = True,
+    num_threads: int = 1,
+) -> Dict[str, torch.Tensor]:
+    if num_threads == 1:
+        return tree_map(lambda x: _pin_tensor(x, empty=empty) if isinstance(x, torch.Tensor) else x, state_dict)
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+            elems, spec = tree_flatten(state_dict)
+            future_to_idx = {}
+            for i, elem in enumerate(elems):
+                if isinstance(elem, torch.Tensor):
+                    future_to_idx[executor.submit(_pin_tensor, elem, empty)] = i
+            for future in concurrent.futures.as_completed(future_to_idx):
+                idx = future_to_idx[future]
+                elems[idx] = future.result()
+            return tree_unflatten(elems, spec)
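The rewritten helper generalizes the old flat `Dict[str, Tensor]` version in three ways: `tree_flatten`/`tree_map` let it handle arbitrarily nested containers (such as optimizer state keyed by parameter id), `empty=False` copies data instead of merely allocating, and `num_threads > 1` fans the per-tensor `pin_memory()` calls out over a thread pool, where the underlying copies run in C++ and can overlap. A usage sketch (the nested shapes are illustrative, and the import path assumes the helper lives in `colossalai.checkpoint_io.utils`):

```python
import torch
from colossalai.checkpoint_io.utils import create_pinned_state_dict

# Nested layout mirroring optimizer state: {param_id: {state_name: tensor}}.
state = {
    0: {"exp_avg": torch.randn(256, 256), "exp_avg_sq": torch.randn(256, 256)},
    1: {"exp_avg": torch.randn(256), "step": torch.tensor(10.0)},
}

# Copy every tensor into pinned memory using 4 worker threads.
pinned = create_pinned_state_dict(state, empty=False, num_threads=4)
assert pinned[0]["exp_avg"].is_pinned()
assert torch.equal(pinned[0]["exp_avg"], state[0]["exp_avg"])
```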