[shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506)

* add APIs

* implement save_sharded_model

* add test for hybrid checkpointio

* implement naive loading for sharded model

* implement efficient sharded model loading

* open a new file for hybrid checkpoint_io

* small fix

* fix circular importing

* fix docstring

* arrange arguments and apis

* small fix
This commit is contained in:
Baizhou Zhang
2023-08-25 22:04:57 +08:00
committed by GitHub
parent de8a65babc
commit 44eab2b27f
7 changed files with 497 additions and 40 deletions

View File

@@ -1,5 +1,6 @@
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
from .index_file import CheckpointIndexFile
__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']