Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-20 09:01:06 +00:00)
[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
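This change threads two new arguments, low_cpu_mem_mode and num_threads, through the load paths of GeneralCheckpointIO. When low_cpu_mem_mode is turned off, the freshly loaded state dict is rebuilt in page-locked (pinned) host memory, which is what lets later host-to-device copies run with non_blocking=True and overlap with computation. The create_pinned_state_dict helper itself (imported from the checkpoint_io utilities) is not part of this diff; the following is only a rough sketch of the idea, assuming it pins each tensor with a small thread pool:

# Illustrative sketch only; not the actual ColossalAI implementation.
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

import torch


def create_pinned_state_dict_sketch(
    state_dict: Dict[str, torch.Tensor], empty: bool = False, num_threads: int = 1
) -> Dict[str, torch.Tensor]:
    """Return a copy of `state_dict` whose tensors live in pinned host memory."""

    def _pin(item):
        key, tensor = item
        if empty:
            # Allocate an uninitialized pinned buffer with the same shape/dtype.
            pinned = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True)
        else:
            # Tensor.pin_memory() returns a copy in page-locked host memory.
            pinned = tensor.pin_memory()
        return key, pinned

    # Pinning is a host-side copy, so a thread pool can parallelize it.
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return dict(pool.map(_pin, state_dict.items()))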
@@ -1,4 +1,3 @@
-import gc
 import logging
 import os
 from functools import reduce
@@ -40,8 +39,17 @@ class GeneralCheckpointIO(CheckpointIO):
     Checkpoint IO
     """

-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+    def load_unsharded_model(
+        self,
+        model: nn.Module,
+        checkpoint: str,
+        strict: bool,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
         model.load_state_dict(checkpoint, strict=strict)

     def save_unsharded_model(
@@ -60,7 +68,14 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)

-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load sharded optimizer with the given path to index file.
         """
@@ -84,6 +99,8 @@ class GeneralCheckpointIO(CheckpointIO):
                 state_dict = load_flat(shard_file)
             else:
                 state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if not low_cpu_mem_mode:
+                state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_states_into_optimizer(optimizer, state_dict, id_map)

         sharded_optimizer_loading_epilogue(optimizer)
@@ -158,11 +175,15 @@ class GeneralCheckpointIO(CheckpointIO):
             f"index located at {save_index_file}."
         )

-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+    def load_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         if checkpoint.endswith(".safetensors"):
             checkpoint = load_flat(checkpoint)
         else:
             checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
         optimizer.load_state_dict(checkpoint)

     def save_unsharded_optimizer(
@@ -256,6 +277,8 @@ class GeneralCheckpointIO(CheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         load shard model, load model from multiple files
@@ -274,9 +297,9 @@ class GeneralCheckpointIO(CheckpointIO):

         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+            if not low_cpu_mem_mode:
+                state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
-            del state_dict
-            gc.collect()

         if strict:
             remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
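For context, here is a minimal usage sketch of the new flags against load_unsharded_model as it appears above; the import path, the checkpoint filename, and the tiny Linear model are placeholders for illustration, not part of this commit:

# Hypothetical usage of the new low_cpu_mem_mode / num_threads parameters.
import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(16, 16)
ckpt_io = GeneralCheckpointIO()

# Default path: keep the loaded tensors in ordinary pageable memory
# (lowest peak host memory, but device copies cannot be truly non-blocking).
ckpt_io.load_unsharded_model(model, "model.pt", strict=True)

# Pinned path: copy the state dict into page-locked buffers with 4 threads,
# so subsequent host-to-device transfers can use non_blocking=True.
ckpt_io.load_unsharded_model(model, "model.pt", strict=True, low_cpu_mem_mode=False, num_threads=4)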