mirror of https://github.com/hpcaitech/ColossalAI.git
[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
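For context on what a non-blocking pinned load buys: checkpoint tensors are first staged in page-locked (pinned) host memory, which is what allows the host-to-device copy to be issued with non_blocking=True and overlap with other work. Below is a minimal PyTorch sketch of that general pattern; it is illustrative only and not the code added by this commit (the path and device are placeholders):

import torch

def load_via_pinned_buffers(path: str, device: torch.device) -> dict:
    # Load onto the CPU first, then pin each tensor so the H2D copy
    # below can actually run asynchronously (non_blocking requires pinned memory).
    state_dict = torch.load(path, map_location="cpu")
    out = {}
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            out[name] = value.pin_memory().to(device, non_blocking=True)
        else:
            out[name] = value
    # Copies are ordered on the current CUDA stream; synchronize if the
    # result must be visible to the host immediately.
    torch.cuda.synchronize()
    return out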
@@ -43,13 +43,17 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         self.coordinator = DistCoordinator()
         self.logger = get_dist_logger()

-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
+    def load_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
         model = model.unwrap()
         checkpoint = utils.load_state_dict(checkpoint)
         model.load_state_dict(checkpoint)

-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
         if checkpoint.endswith(".safetensors"):
             checkpoint = load_flat(checkpoint, seperator=".")
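Given the widened signatures above, a caller that prefers load speed over the default low-memory path can pass the new arguments explicitly. A sketch, assuming `model` and `optimizer` have already been boosted into TorchFSDPModel / FSDPOptimizerWrapper and the checkpoint paths are placeholders:

ckpt_io = TorchFSDPCheckpointIO()

# Default behaviour: low CPU memory mode, single-threaded.
ckpt_io.load_unsharded_model(model, "model_ckpt.bin", strict=True)

# Opt out of low_cpu_mem_mode and let the loader use several threads
# when staging the state dict.
ckpt_io.load_unsharded_model(
    model, "model_ckpt.bin", strict=True, low_cpu_mem_mode=False, num_threads=4
)
ckpt_io.load_unsharded_optimizer(
    optimizer, "optim_ckpt.safetensors", low_cpu_mem_mode=False, num_threads=4
)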
@@ -232,6 +236,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         Load model to checkpoint but only on master process.
@@ -354,7 +360,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
             f"index located at {save_index_file}."
         )

-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        size_per_shard: int,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load optimizer to checkpoint but only on master process.
         """
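The `num_threads` argument suggests that pinning a large state dict can be parallelised across tensors. One plausible shape for such a helper (an assumption for illustration; not necessarily how the library implements it internally) is:

from concurrent.futures import ThreadPoolExecutor

import torch

def pin_state_dict(state_dict: dict, num_threads: int = 1) -> dict:
    # Pin each tensor in parallel; pinning is host-side work, so a thread
    # pool can hide part of its cost on large checkpoints.
    def _pin(item):
        name, value = item
        if isinstance(value, torch.Tensor):
            return name, value.pin_memory()
        return name, value

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return dict(pool.map(_pin, state_dict.items()))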