Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506)
* add APIs
* implement save_sharded_model
* add test for hybrid checkpointio
* implement naive loading for sharded model
* implement efficient sharded model loading
* open a new file for hybrid checkpoint_io
* small fix
* fix circular importing
* fix docstring
* arrange arguments and apis
* small fix
@@ -13,7 +13,12 @@ from torch.optim import Optimizer
 
 from colossalai.interface import OptimizerWrapper
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor.d_tensor import is_distributed_tensor
+from colossalai.tensor.d_tensor import (
+    is_customized_distributed_tensor,
+    is_distributed_tensor,
+    to_global,
+    to_global_for_customized_distributed_tensor,
+)
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -88,8 +93,28 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
     return False
 
 
+def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False):
+    """
+    Gather the complete parameter for saving if passed in param is distributed.
+
+    Args:
+        param (torch.Tensor): A model parameter, might be d_tensor.
+        keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False.
+
+    Returns:
+        torch.Tensor: the complete parameter
+    """
+    param_ = param if keep_vars else param.detach()
+    if is_distributed_tensor(param_):
+        return to_global(param_)
+    elif is_customized_distributed_tensor(param_):
+        return to_global_for_customized_distributed_tensor(param_)
+    else:
+        return param_
+
+
 # ======================================
-# Helper functions for saving shard file
+# Helper classes and functions for saving shard file
 # ======================================
 def unwrap_optimizer(optimizer: OptimizerWrapper):
     '''
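For orientation, below is a minimal sketch of how a checkpoint writer might call the new gather_distributed_param helper when assembling a full state dict for saving. The import path colossalai.checkpoint_io.utils and the build_full_state_dict wrapper are assumptions for illustration only; they are not part of this commit.

# Minimal usage sketch (assumed import path; build_full_state_dict is hypothetical).
from collections import OrderedDict

import torch.nn as nn

from colossalai.checkpoint_io.utils import gather_distributed_param


def build_full_state_dict(model: nn.Module) -> OrderedDict:
    # Gather distributed (d_tensor) parameters back into complete global tensors
    # so they can be written to a checkpoint file.
    full_state = OrderedDict()
    for name, param in model.named_parameters():
        full_state[name] = gather_distributed_param(param, keep_vars=False)
    return full_state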
@@ -104,6 +129,31 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
     return unwrapped_optim
 
 
+class StateDictSharder:
+
+    def __init__(self, size_per_shard: int) -> None:
+        self.max_shard_size = size_per_shard
+        self.current_block = OrderedDict()
+        self.current_block_size = 0
+
+    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
+        tensor_size = calculate_tensor_size(tensor)
+        ret_block = None
+        ret_block_size = 0
+
+        # before we return the current block and create a new block,
+        # we need to ensure that the current block is not empty
+        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
+            ret_block = self.current_block
+            ret_block_size = self.current_block_size
+            self.current_block = OrderedDict()
+            self.current_block_size = 0
+
+        self.current_block[name] = tensor
+        self.current_block_size += tensor_size
+        return ret_block, ret_block_size
+
+
 def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
                            checkpoint: str,
                            index_file: "CheckpointIndexFile",
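StateDictSharder accumulates tensors into an OrderedDict block and, once adding the next tensor would push the block past max_shard_size (and the block is non-empty), hands the finished block back to the caller and starts a new one. A sketch of the intended driving loop follows; shard_state_dict is a hypothetical wrapper, and note that the last, partially filled block has to be flushed explicitly after the loop.

# Hypothetical driving loop around StateDictSharder.append(), for illustration only.
def shard_state_dict(state_dict, size_per_shard):
    sharder = StateDictSharder(size_per_shard)
    for name, tensor in state_dict.items():
        full_block, block_size = sharder.append(name, tensor)
        if full_block is not None:
            # append() returned a completed block: yield it and keep going.
            yield full_block, block_size
    # The final block never trips the size check inside append(), so emit it here.
    yield sharder.current_block, sharder.current_block_size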
@@ -126,9 +176,10 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
 
     total_size = 0
     for idx, shard_pair in enumerate(sharded_state_dict):
-        if not is_master:
-            continue
         shard, current_size = shard_pair
+        if not is_master:
+            del shard
+            continue
         shard_file = get_shard_filename(base_filename, idx)
         total_size = total_size + current_size
         for key in shard.keys():
@@ -137,6 +188,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
 
         # Only save on master rank.
         save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+        del shard
 
     return total_size
 
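Pieced together from the two hunks above, the reworked saving loop now unpacks every shard on every rank but frees it immediately with del shard, so the gathered tensors are not kept alive on non-master ranks or after the master has written them. A rough, non-runnable consolidation of those hunks is shown below; the per-key index bookkeeping in the middle of the loop is unchanged by this commit and deliberately left elided.

# Rough consolidation of the hunks above; the elided middle of the loop is
# unchanged by this commit and intentionally not filled in here.
total_size = 0
for idx, shard_pair in enumerate(sharded_state_dict):
    shard, current_size = shard_pair
    if not is_master:
        # Non-master ranks still drain the iterator (which may perform gathers)
        # but drop the block immediately to avoid holding gathered tensors in memory.
        del shard
        continue
    shard_file = get_shard_filename(base_filename, idx)
    total_size = total_size + current_size
    # ... per-key index bookkeeping over shard.keys(), unchanged, not shown ...

    # Only save on master rank.
    save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
    del shard

return total_size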