mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -355,7 +355,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
|
||||
def load_sharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint_index_file: Path,
|
||||
strict: bool = False,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load sharded model with the given path to index file of checkpoint folder.
|
||||
|
||||
@@ -403,6 +410,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
|
||||
load_state_dict_into_model(
|
||||
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
|
||||
@@ -632,7 +641,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint_index_file: str,
|
||||
prefix: str = "",
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file of checkpoint folder.
|
||||
|
||||
@@ -706,6 +722,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_dict = load_flat(file_path)
|
||||
else:
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
||||
@@ -789,7 +807,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
else:
|
||||
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
|
||||
def load_unsharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint: str,
|
||||
strict: bool = False,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load model from a single file with the given path of checkpoint.
|
||||
|
||||
@@ -812,6 +837,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
# has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
|
||||
# model.load_state_dict can be directly called.
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
# Update master params if mixed-precision training is enabled.
|
||||
@@ -912,7 +939,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
else:
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Load optimizer from a file with given path.
|
||||
|
||||
@@ -940,6 +969,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
state_dict = load_flat(checkpoint)
|
||||
else:
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||
|
||||
# Load param_groups.
|
||||
updated_groups = []
|
||||
|
Reference in New Issue
Block a user