mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -85,7 +85,12 @@ class CheckpointIO(ABC):
|
||||
self._sync_io()
|
||||
|
||||
def load_model(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
|
||||
self,
|
||||
model: Union[nn.Module, ModelWrapper],
|
||||
checkpoint: str,
|
||||
strict: bool = True,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
) -> Union[nn.Module, ModelWrapper]:
|
||||
"""
|
||||
Load model from checkpoint.
|
||||
@@ -100,6 +105,8 @@ class CheckpointIO(ABC):
|
||||
Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
|
||||
strict (bool): whether to strictly enforce that the param name in
|
||||
the checkpoint match the keys returned by this module's.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
# since we only support loaded sharded and unsharded weight format
|
||||
# containing no distributed tensors, dtensor -> full tensor conversion
|
||||
@@ -111,17 +118,25 @@ class CheckpointIO(ABC):
|
||||
origin_model = model
|
||||
|
||||
if index_file_exists:
|
||||
self.load_sharded_model(model, index_file_path, strict)
|
||||
self.load_sharded_model(
|
||||
model, index_file_path, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
else:
|
||||
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
|
||||
if path.is_file():
|
||||
self.load_unsharded_model(model, str(path), strict)
|
||||
self.load_unsharded_model(
|
||||
model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
else:
|
||||
path = Path(checkpoint, WEIGHTS_NAME)
|
||||
if path.is_file():
|
||||
self.load_unsharded_model(model, str(path), strict)
|
||||
self.load_unsharded_model(
|
||||
model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
else:
|
||||
self.load_unsharded_model(model, checkpoint, strict)
|
||||
self.load_unsharded_model(
|
||||
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
return origin_model
|
||||
|
||||
@@ -178,7 +193,14 @@ class CheckpointIO(ABC):
|
||||
else:
|
||||
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
|
||||
def load_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
checkpoint: str,
|
||||
prefix: str = None,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load optimizer from checkpoint.
|
||||
|
||||
@@ -187,7 +209,8 @@ class CheckpointIO(ABC):
|
||||
checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
|
||||
prefix (str, optional): A prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Defaults to None.
|
||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
|
||||
index_file_exists, index_file_path = has_index_file(checkpoint)
|
||||
@@ -198,9 +221,13 @@ class CheckpointIO(ABC):
|
||||
|
||||
if index_file_exists:
|
||||
# the existence of index file means it is a sharded checkpoint
|
||||
self.load_sharded_optimizer(optimizer, index_file_path, prefix)
|
||||
self.load_sharded_optimizer(
|
||||
optimizer, index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
else:
|
||||
self.load_unsharded_optimizer(optimizer, checkpoint)
|
||||
self.load_unsharded_optimizer(
|
||||
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||
)
|
||||
|
||||
def save_optimizer(
|
||||
self,
|
||||
@@ -238,7 +265,9 @@ class CheckpointIO(ABC):
|
||||
# Abstract methods for model loading/saving implementation
|
||||
# ========================================================
|
||||
@abstractmethod
|
||||
def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
|
||||
def load_sharded_model(
|
||||
self, model: nn.Module, index_file_path: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Load model from sharded checkpoint.
|
||||
|
||||
@@ -247,10 +276,14 @@ class CheckpointIO(ABC):
|
||||
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
||||
strict (bool): whether to strictly enforce that the param name in
|
||||
the checkpoint match the keys returned by this module's.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
|
||||
def load_unsharded_model(
|
||||
self, model: nn.Module, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Load model from unsharded checkpoint.
|
||||
|
||||
@@ -259,6 +292,8 @@ class CheckpointIO(ABC):
|
||||
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
||||
strict (bool): whether to strictly enforce that the param name in
|
||||
the checkpoint match the keys returned by this module's.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -303,7 +338,14 @@ class CheckpointIO(ABC):
|
||||
# ========================================================
|
||||
|
||||
@abstractmethod
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
|
||||
def load_sharded_optimizer(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
index_file_path: str,
|
||||
prefix: str,
|
||||
low_cpu_mem_mode: bool = True,
|
||||
num_threads: int = 1,
|
||||
):
|
||||
"""
|
||||
Load optimizer from sharded checkpoint.
|
||||
|
||||
@@ -311,16 +353,22 @@ class CheckpointIO(ABC):
|
||||
optimizer (Optimizer): optimizer to be loaded.
|
||||
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
||||
prefix (str): prefix for the optimizer checkpoint.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
||||
def load_unsharded_optimizer(
|
||||
self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||
):
|
||||
"""
|
||||
Load optimizer from unsharded checkpoint.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): optimizer to be loaded.
|
||||
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
||||
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
Reference in New Issue
Block a user