[shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506)

* add APIs

* implement save_sharded_model

* add test for hybrid checkpointio

* implement naive loading for sharded model

* implement efficient sharded model loading

* open a new file for hybrid checkpoint_io

* small fix

* fix circular importing

* fix docstring

* arrange arguments and apis

* small fix
Baizhou Zhang
2023-08-25 22:04:57 +08:00
committed by GitHub
parent de8a65babc
commit 44eab2b27f
7 changed files with 497 additions and 40 deletions
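For context, the sharded checkpoint path added in this commit is normally exercised through the Booster API rather than called directly. A minimal usage sketch, not part of this commit; constructor arguments and the launch call vary across ColossalAI versions:

import colossalai
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# launched with torchrun; the `config` argument was still required around this release
colossalai.launch_from_torch(config={})

# tensor parallelism only, purely for illustration
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
booster = Booster(plugin=plugin)

model = nn.Linear(128, 128)
model, *_ = booster.boost(model)

# shard=True writes one weight file per shard plus an index file
# instead of a single monolithic file; size_per_shard is in MB
booster.save_model(model, "./ckpt_dir", shard=True, size_per_shard=1024)
# for a sharded checkpoint, pass back the saved directory (or its index file)
booster.load_model(model, "./ckpt_dir")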

@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
+from colossalai.checkpoint_io.utils import gather_distributed_param
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
     distribute_tensor_with_customization,
@@ -56,13 +57,7 @@ class ParallelModule(nn.Module, ABC):
"""
for name, param in self._parameters.items():
if param is not None:
param_ = param if keep_vars else param.detach()
if is_distributed_tensor(param_):
destination[prefix + name] = to_global(param_)
elif is_customized_distributed_tensor(param_):
destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
else:
destination[prefix + name] = param_
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
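The one-line replacement above relies on the new `gather_distributed_param` helper from `colossalai.checkpoint_io.utils`. A sketch of the behavior it centralizes, reconstructed from the branch removed above; the real helper may differ in detail:

from colossalai.tensor.d_tensor import (
    is_customized_distributed_tensor,
    is_distributed_tensor,
    to_global,
    to_global_for_customized_distributed_tensor,
)

def gather_distributed_param(param, keep_vars: bool = False):
    # Reconstructed sketch: gather a possibly sharded parameter into a full
    # global tensor before it is written into the state dict.
    param_ = param if keep_vars else param.detach()
    if is_distributed_tensor(param_):
        # regular d-tensor: gather according to its sharding spec
        return to_global(param_)
    elif is_customized_distributed_tensor(param_):
        # customized sharding: use the user-registered gather function
        return to_global_for_customized_distributed_tensor(param_)
    else:
        # plain tensor: nothing to gather
        return param_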