[checkpointio] support non blocking pin load (#6172)

* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-12-25 17:03:25 +08:00
committed by GitHub
parent 836992438f
commit af06d162cf
15 changed files with 484 additions and 174 deletions

View File

@@ -85,7 +85,12 @@ class CheckpointIO(ABC):
self._sync_io()
def load_model(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
@@ -100,6 +105,8 @@ class CheckpointIO(ABC):
Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
# since we only support loaded sharded and unsharded weight format
# containing no distributed tensors, dtensor -> full tensor conversion
@@ -111,17 +118,25 @@ class CheckpointIO(ABC):
origin_model = model
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
self.load_sharded_model(
model, index_file_path, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
self.load_unsharded_model(
model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
path = Path(checkpoint, WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
self.load_unsharded_model(
model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
self.load_unsharded_model(model, checkpoint, strict)
self.load_unsharded_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
return origin_model
@@ -178,7 +193,14 @@ class CheckpointIO(ABC):
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
def load_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
prefix: str = None,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from checkpoint.
@@ -187,7 +209,8 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
index_file_exists, index_file_path = has_index_file(checkpoint)
@@ -198,9 +221,13 @@ class CheckpointIO(ABC):
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
self.load_sharded_optimizer(optimizer, index_file_path, prefix)
self.load_sharded_optimizer(
optimizer, index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
self.load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_optimizer(
self,
@@ -238,7 +265,9 @@ class CheckpointIO(ABC):
# Abstract methods for model loading/saving implementation
# ========================================================
@abstractmethod
def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
def load_sharded_model(
self, model: nn.Module, index_file_path: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load model from sharded checkpoint.
@@ -247,10 +276,14 @@ class CheckpointIO(ABC):
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
def load_unsharded_model(
self, model: nn.Module, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load model from unsharded checkpoint.
@@ -259,6 +292,8 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
@@ -303,7 +338,14 @@ class CheckpointIO(ABC):
# ========================================================
@abstractmethod
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from sharded checkpoint.
@@ -311,16 +353,22 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
def load_unsharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load optimizer from unsharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod