[shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506)

* add APIs

* implement save_sharded_model

* add test for hybrid checkpointio

* implement naive loading for sharded model

* implement efficient sharded model loading

* open a new file for hybrid checkpoint_io

* small fix

* fix circular importing

* fix docstring

* arrange arguments and apis

* small fix
Baizhou Zhang
2023-08-25 22:04:57 +08:00
committed by GitHub
parent de8a65babc
commit 44eab2b27f
7 changed files with 497 additions and 40 deletions
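
For context, the feature added here is exercised through the Booster API. A rough usage sketch follows (not part of this commit; the launch config and plugin arguments such as tp_size, pp_size and size_per_shard are illustrative and may differ between versions):

    # Rough sketch only: how a user would hit the new sharded-checkpoint path.
    # Plugin/launch arguments below are illustrative, not taken from this commit.
    import torch
    import torch.nn as nn
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch(config={})              # expects torchrun-style env vars

    plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # hybrid tensor/pipeline parallelism
    booster = Booster(plugin=plugin)

    model = nn.Linear(1024, 1024)
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer, *_ = booster.boost(model, optimizer)

    # shard=True selects the sharded save path added by this commit: weights are
    # gathered from their tensor-parallel shards, split into files no larger than
    # size_per_shard, and an index file maps each weight name to its shard file.
    booster.save_model(model, "./hybrid_ckpt", shard=True, size_per_shard=1024)
    booster.load_model(model, "./hybrid_ckpt")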


@@ -13,7 +13,12 @@ from torch.optim import Optimizer
 from colossalai.interface import OptimizerWrapper
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor.d_tensor import is_distributed_tensor
+from colossalai.tensor.d_tensor import (
+    is_customized_distributed_tensor,
+    is_distributed_tensor,
+    to_global,
+    to_global_for_customized_distributed_tensor,
+)
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -88,8 +93,28 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
         return False
 
 
+def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False):
+    """
+    Gather the complete parameter for saving if passed in param is distributed.
+
+    Args:
+        param (torch.Tensor): A model parameter, might be d_tensor.
+        keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False.
+
+    Returns:
+        torch.Tensor: the complete parameter
+    """
+    param_ = param if keep_vars else param.detach()
+    if is_distributed_tensor(param_):
+        return to_global(param_)
+    elif is_customized_distributed_tensor(param_):
+        return to_global_for_customized_distributed_tensor(param_)
+    else:
+        return param_
+
+
 # ======================================
-# Helper functions for saving shard file
+# Helper classes and functions for saving shard file
 # ======================================
 def unwrap_optimizer(optimizer: OptimizerWrapper):
     '''
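
gather_distributed_param above is the piece that reassembles tensor-parallel shards into full tensors at save time. A minimal sketch of how a checkpoint writer could use it when flattening a model into a savable state dict (the helper name build_full_state_dict is made up for illustration):

    # Illustrative helper (not from this commit): gather every parameter that is a
    # d_tensor back to its global shape before it goes into the checkpoint.
    from collections import OrderedDict

    import torch.nn as nn

    def build_full_state_dict(model: nn.Module, keep_vars: bool = False) -> OrderedDict:
        state_dict = OrderedDict()
        for name, param in model.named_parameters():
            # distributed tensors (regular or customized) are gathered to their
            # global shape; ordinary tensors pass through unchanged
            state_dict[name] = gather_distributed_param(param, keep_vars=keep_vars)
        return state_dict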
@@ -104,6 +129,31 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
     return unwrapped_optim
 
 
+class StateDictSharder:
+
+    def __init__(self, size_per_shard: int) -> None:
+        self.max_shard_size = size_per_shard
+        self.current_block = OrderedDict()
+        self.current_block_size = 0
+
+    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
+        tensor_size = calculate_tensor_size(tensor)
+        ret_block = None
+        ret_block_size = 0
+
+        # before we return the current block and create a new block,
+        # we need to ensure that the current block is not empty
+        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
+            ret_block = self.current_block
+            ret_block_size = self.current_block_size
+            self.current_block = OrderedDict()
+            self.current_block_size = 0
+
+        self.current_block[name] = tensor
+        self.current_block_size += tensor_size
+        return ret_block, ret_block_size
+
+
 def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
                            checkpoint: str,
                            index_file: "CheckpointIndexFile",
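
StateDictSharder is a small accumulator: append() buffers tensors into the current block and hands back the previous block once adding the next tensor would push it past max_shard_size. A sketch of a shard generator driving it (the function name is illustrative; the commit's real generators live in the new hybrid checkpoint IO file):

    # Illustrative generator: turn a state dict into (block, block_size) pairs,
    # each block staying under size_per_shard as measured by calculate_tensor_size.
    def shard_state_dict(state_dict, size_per_shard: int):
        sharder = StateDictSharder(size_per_shard)
        for name, tensor in state_dict.items():
            block, block_size = sharder.append(name, tensor)
            if block is not None:
                yield block, block_size      # previous block is full, emit it
        # emit the final, partially filled block as well
        yield sharder.current_block, sharder.current_block_size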
@@ -126,9 +176,10 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
     total_size = 0
 
     for idx, shard_pair in enumerate(sharded_state_dict):
-        if not is_master:
-            continue
         shard, current_size = shard_pair
+        if not is_master:
+            del shard
+            continue
         shard_file = get_shard_filename(base_filename, idx)
         total_size = total_size + current_size
         for key in shard.keys():
@@ -137,6 +188,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
         # Only save on master rank.
         save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+        del shard
 
     return total_size
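
Taken together, the save path is roughly: gather and shard the state dict, let save_state_dict_shards write one file per block on the master rank while registering each weight in the index file, then persist the index. A hedged end-to-end sketch reusing the two helpers sketched above (the CheckpointIndexFile constructor and the glue code are assumptions, not the commit's implementation):

    # End-to-end sketch of the sharded model save; only save_state_dict_shards,
    # CheckpointIndexFile, WEIGHTS_NAME and SAFE_WEIGHTS_NAME come from the diff,
    # the rest is illustrative glue.
    def save_sharded_model_sketch(model, checkpoint_dir: str, size_per_shard: int,
                                  is_master: bool, use_safetensors: bool = False) -> int:
        state_dict = build_full_state_dict(model)               # gather d_tensors
        shards = shard_state_dict(state_dict, size_per_shard)   # lazy (block, size) pairs

        index_file = CheckpointIndexFile(checkpoint_dir)        # assumed constructor
        base_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
        total_size = save_state_dict_shards(sharded_state_dict=shards,
                                            checkpoint=checkpoint_dir,
                                            index_file=index_file,
                                            base_filename=base_name,
                                            is_master=is_master,
                                            use_safetensors=use_safetensors)
        # the master rank then records total_size in the index and writes the
        # index json next to the shard files (index-file API omitted here)
        return total_size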