[shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506)

* add APIs

* implement save_sharded_model

* add test for hybrid checkpointio

* implement naive loading for sharded model

* implement efficient sharded model loading

* open a new file for hybrid checkpoint_io

* small fix

* fix circular importing

* fix docstring

* arrange arguments and apis

* small fix
Author: Baizhou Zhang
Date: 2023-08-25 22:04:57 +08:00
Committed by: GitHub
Parent: de8a65babc
Commit: 44eab2b27f
7 changed files with 497 additions and 40 deletions
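As context for the new APIs: sharded checkpoint IO for this plugin is normally driven through the Booster front end. The sketch below is a hedged illustration only; the launch call, plugin arguments, and save_model/load_model signatures are best-effort assumptions from Colossal-AI's public docs of this period, not taken from this diff.

    # Hedged sketch: sharded checkpoint save/load via Booster with
    # HybridParallelPlugin. Signatures are assumptions from Colossal-AI
    # docs of this era, not from this commit.
    import colossalai
    import torch.nn as nn
    from torch.optim import SGD
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch(config={})  # run under torchrun, e.g. 2 ranks

    plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
    booster = Booster(plugin=plugin)

    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
    optimizer = SGD(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # shard=True writes an index plus multiple shard files instead of one
    # monolithic file; size_per_shard caps each shard file (in MB).
    booster.save_model(model, "ckpt", shard=True, size_per_shard=1024)
    booster.load_model(model, "ckpt")

With shard=True the checkpoint becomes an index file plus several shard files, which is the behavior the save_sharded_model and sharded-loading commits above implement for hybrid-parallel models.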


@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.checkpoint_io.utils import StateDictSharder
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
@@ -657,7 +657,7 @@ class ZeroDDP(ColoDDP):
         Yields:
             Iterator[OrderedDict]: A generator of state dict shard
         """
-        sharder = _StateDictSharder(max_shard_size)
+        sharder = StateDictSharder(max_shard_size)
 
         # get the mapping between copies and fp16 parameters
         fp16_to_fp32 = dict()
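Because state_dict_shard is a generator of (block, block_size) pairs, a checkpoint writer can stream shards to disk without ever materializing the full state dict. A hypothetical consumer follows; the file names and index layout are illustrative, not the repo's actual on-disk format.

    # Hypothetical consumer of a state_dict_shard()-style generator: write
    # each yielded block to its own file and record a name -> file index.
    import json
    import os
    import torch

    def save_shards(shard_iter, save_dir: str) -> None:
        os.makedirs(save_dir, exist_ok=True)
        weight_map = {}
        for i, (block, block_size) in enumerate(shard_iter):
            shard_file = f"model-{i:05d}.bin"
            torch.save(block, os.path.join(save_dir, shard_file))
            for key in block:  # remember which file holds each tensor
                weight_map[key] = shard_file
        with open(os.path.join(save_dir, "model.index.json"), "w") as f:
            json.dump({"weight_map": weight_map}, f)

    if __name__ == "__main__":
        # demo with a fake two-shard generator
        fake = iter([({"w": torch.zeros(4)}, 16), ({"b": torch.ones(4)}, 16)])
        save_shards(fake, "ckpt_demo")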
@@ -705,30 +705,6 @@ class ZeroDDP(ColoDDP):
         yield sharder.current_block, sharder.current_block_size
 
 
-class _StateDictSharder:
-    def __init__(self, max_shard_size: int) -> None:
-        self.max_shard_size = max_shard_size
-        self.current_block = OrderedDict()
-        self.current_block_size = 0
-
-    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
-        tensor_size = calculate_tensor_size(tensor)
-        ret_block = None
-        ret_block_size = 0
-
-        # before we return the current block and create a new block,
-        # we need to ensure that the current block is not empty
-        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
-            ret_block = self.current_block
-            ret_block_size = self.current_block_size
-            self.current_block = OrderedDict()
-            self.current_block_size = 0
-
-        self.current_block[name] = tensor
-        self.current_block_size += tensor_size
-        return ret_block, ret_block_size
-
-
 class GeminiDDP(ZeroDDP):
     def __init__(self,
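The deleted class is not lost: it is promoted to colossalai.checkpoint_io.utils.StateDictSharder so the new hybrid-parallel checkpoint IO and ZeroDDP can share one implementation. Below is a self-contained sketch of the same accumulate-and-flush logic with a usage demo; calculate_tensor_size is approximated here by a raw byte count, and the real utility may measure in different units.

    # Self-contained sketch of the StateDictSharder accumulate-and-flush
    # pattern shown in the removed class above.
    from collections import OrderedDict
    from typing import Optional, Tuple
    import torch

    def _tensor_size(t: torch.Tensor) -> int:
        # stand-in for colossalai.checkpoint_io.utils.calculate_tensor_size
        return t.numel() * t.element_size()

    class StateDictSharderSketch:
        """Accumulates tensors into blocks of at most max_shard_size bytes."""

        def __init__(self, max_shard_size: int) -> None:
            self.max_shard_size = max_shard_size
            self.current_block: OrderedDict = OrderedDict()
            self.current_block_size = 0

        def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
            size = _tensor_size(tensor)
            ret_block, ret_block_size = None, 0
            # flush the current block only if it is non-empty, so a single
            # oversized tensor still gets a block of its own
            if self.current_block_size + size > self.max_shard_size and self.current_block_size > 0:
                ret_block, ret_block_size = self.current_block, self.current_block_size
                self.current_block, self.current_block_size = OrderedDict(), 0
            self.current_block[name] = tensor
            self.current_block_size += size
            return ret_block, ret_block_size

    # usage: feed tensors in order, emit full blocks, flush the remainder
    sharder = StateDictSharderSketch(max_shard_size=1 << 20)  # 1 MiB
    for name, param in torch.nn.Linear(512, 512).named_parameters():
        block, _ = sharder.append(name, param.detach().cpu())
        if block is not None:
            print(f"emit shard with {len(block)} tensors")
    print(f"final shard with {len(sharder.current_block)} tensors")

Flushing lazily, only when the next tensor would overflow a non-empty block, is what lets the state_dict_shard generator above yield complete shards while the final partial block is emitted by the caller.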