Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
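What "pin load" buys: a state dict staged in page-locked (pinned) host memory can be copied to the accelerator with non_blocking=True, so host-to-device transfers are queued asynchronously instead of blocking the host. A minimal PyTorch sketch of that mechanism (illustrative only, not code from this commit):

import torch

# Illustrative only: pinned host memory enables asynchronous H2D copies.
cpu_tensor = torch.randn(1024, 1024).pin_memory()  # page-locked host buffer
if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        # With a pinned source, this copy is enqueued without blocking the
        # host and can overlap compute running on other streams.
        gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)
    copy_stream.synchronize()  # ensure the transfer finished before use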
@@ -20,6 +20,7 @@ from torch.utils.data import DataLoader
 from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
 from colossalai.checkpoint_io.utils import (
+    create_pinned_state_dict,
     get_optimizer_base_filenames,
     get_shard_filename,
     load_param_groups_into_optimizer,
@@ -145,7 +146,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)

-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         use_async = checkpoint.endswith(".safetensors")
         if use_async:
             from colossalai.utils.safetensors import load_flat
@@ -153,6 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             checkpoint = load_flat(checkpoint)
         else:
             checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
         optimizer.load_state_dict(checkpoint)

     def save_sharded_optimizer(
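The new fast path above hinges on create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads), whose body is not part of this diff. A rough sketch of what such a helper plausibly does, inferred only from its call signature (empty=False presumably means "copy the values, don't just allocate pinned buffers"; treat every detail here as an assumption):

from concurrent.futures import ThreadPoolExecutor
import torch

def pin_state_dict_sketch(state_dict: dict, num_threads: int = 1) -> dict:
    # Hypothetical stand-in for create_pinned_state_dict (not the real
    # implementation): recursively copy tensors into page-locked memory,
    # fanning work out across a thread pool when num_threads > 1.
    def pin(value):
        if isinstance(value, torch.Tensor):
            return value.pin_memory()  # copies into a pinned host buffer
        if isinstance(value, dict):
            return {k: pin(v) for k, v in value.items()}
        return value  # non-tensor entries (step counts, hyperparams) pass through

    if num_threads <= 1:
        return {k: pin(v) for k, v in state_dict.items()}
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        keys = list(state_dict)
        return dict(zip(keys, pool.map(lambda k: pin(state_dict[k]), keys)))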
@@ -239,7 +244,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             ranks=[0],
         )

-    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
+    def load_sharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        index_file_path: str,
+        prefix: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """Load sharded optimizer with the given path to index file.

         Args:
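For the sharded path, index_file_path points at an index mapping state entries to shard files, and the loader visits each shard in turn. A generic sketch of that lookup, assuming the widespread {"weight_map": {...}} JSON layout (ColossalAI's CheckpointIndexFile is its own class; this only shows the shape of the idea):

import json
from pathlib import Path

def iter_shard_files(index_file_path: str):
    # Generic sketch of walking a sharded-checkpoint index; assumes the
    # common {"weight_map": {entry_name: shard_file}} JSON layout.
    index = json.loads(Path(index_file_path).read_text())
    # Many entries share one shard file; dict.fromkeys dedupes, keeps order.
    for shard_file in dict.fromkeys(index["weight_map"].values()):
        yield Path(index_file_path).parent / shard_file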
@@ -283,14 +295,28 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                     if padding_size > 0:
                         v = torch.nn.functional.pad(v, [0, padding_size])
                     v_list = v.split(v.numel() // self.coordinator.world_size)
-                    state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+                    state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
+                    if low_cpu_mem_mode:
+                        state_dict[param_idx][k] = state_dict[param_idx][k].clone()
+
+        if not low_cpu_mem_mode:
+            state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
         load_states_into_optimizer(optimizer, state_dict, id_map)
         sharded_optimizer_loading_epilogue(optimizer)

-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
-        super().load_unsharded_model(model, checkpoint, strict)
+        super().load_unsharded_model(
+            model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )
         model.update_master_params()

     def load_sharded_model(
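The hunk above also changes when shards are materialized: with low_cpu_mem_mode=True each rank clones its slice immediately, so the full flattened tensor can be freed right away; with low_cpu_mem_mode=False the slices remain views and the whole state dict is copied once into pinned memory at the end. The shard-selection arithmetic itself, pulled out into a standalone sketch with made-up world_size and rank values:

import torch

# Sketch of the pad-split-select logic from the diff, outside ColossalAI:
# pad the flattened state so its length divides evenly, split it into
# world_size equal chunks, and keep only this rank's chunk.
world_size, rank = 4, 1  # hypothetical values
v = torch.arange(10, dtype=torch.float32)  # numel=10 is not divisible by 4
padding_size = (world_size - v.numel() % world_size) % world_size  # -> 2
if padding_size > 0:
    v = torch.nn.functional.pad(v, [0, padding_size])  # numel -> 12
v_list = v.split(v.numel() // world_size)  # four chunks of 3 elements each
local_shard = v_list[rank].detach().clone()  # rank 1's slice: [3., 4., 5.]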
@@ -300,10 +326,20 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
-        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        super().load_sharded_model(
+            model,
+            checkpoint_index_file,
+            strict,
+            use_safetensors,
+            load_sub_module,
+            low_cpu_mem_mode=low_cpu_mem_mode,
+            num_threads=num_threads,
+        )
         model.update_master_params()

     def save_unsharded_model(
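How a caller might reach the new knobs, per the signatures added in this diff (a sketch only: the checkpoint path is a placeholder, the boost step is elided, and in practice these kwargs would typically be forwarded from the booster-level checkpoint API rather than called directly):

from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin()
ckpt_io = plugin.get_checkpoint_io()
# model, optimizer = booster.boost(model, optimizer) must have run first.

# Default path: low CPU memory, every shard cloned as it is read.
# ckpt_io.load_unsharded_optimizer(optimizer, "optimizer.safetensors")

# Pinned path: skip per-shard clones, pin the whole state dict with 4
# threads, enabling non-blocking copies to the accelerator afterwards.
# ckpt_io.load_unsharded_optimizer(
#     optimizer, "optimizer.safetensors", low_cpu_mem_mode=False, num_threads=4
# )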