Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506)
* add APIs
* implement save_sharded_model
* add test for hybrid checkpointio
* implement naive loading for sharded model
* implement efficient sharded model loading
* open a new file for hybrid checkpoint_io
* small fix
* fix circular importing
* fix docstring
* arrange arguments and apis
* small fix
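The commit message above describes new sharded checkpoint APIs for models run under HybridParallelPlugin. Below is a minimal, hypothetical usage sketch of that flow through the Booster interface; the parallel sizes, the `shard=True` flag, the checkpoint path, and the launch call are assumptions for illustration, not taken from this commit.

```python
# Hypothetical usage sketch (not from this commit): saving and loading a model
# boosted by HybridParallelPlugin as a sharded checkpoint via the Booster API.
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes the script is launched with torchrun in a distributed setting.
colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # assumed parallel config
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)
model, *_ = booster.boost(model)

# shard=True asks the plugin's checkpoint IO to write per-shard files plus an
# index file instead of one monolithic state dict.
booster.save_model(model, "checkpoints/linear", shard=True)
booster.load_model(model, "checkpoints/linear")
```

With this plugin, `save_model`/`load_model` are expected to dispatch to the hybrid checkpoint IO this commit adds, which relies on modules gathering their distributed parameters as in the diff below.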
@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
 
+from colossalai.checkpoint_io.utils import gather_distributed_param
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
     distribute_tensor_with_customization,
@@ -56,13 +57,7 @@ class ParallelModule(nn.Module, ABC):
         """
         for name, param in self._parameters.items():
             if param is not None:
-                param_ = param if keep_vars else param.detach()
-                if is_distributed_tensor(param_):
-                    destination[prefix + name] = to_global(param_)
-                elif is_customized_distributed_tensor(param_):
-                    destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
-                else:
-                    destination[prefix + name] = param_
+                destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
 
         for name, buf in self._buffers.items():
             if buf is not None and name not in self._non_persistent_buffers_set:
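The hunk replaces the inline distributed-tensor branch with a single call to `gather_distributed_param`. Below is a sketch of what that helper plausibly does, inferred from the inline code it replaces; the real implementation lives in `colossalai.checkpoint_io.utils` and may differ in detail.

```python
# Sketch inferred from the removed inline branch above; not the actual
# colossalai.checkpoint_io.utils implementation.
import torch

from colossalai.tensor.d_tensor import (
    is_customized_distributed_tensor,
    is_distributed_tensor,
    to_global,
    to_global_for_customized_distributed_tensor,
)


def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
    """Return a full (gathered) copy of a possibly sharded parameter."""
    param_ = param if keep_vars else param.detach()
    if is_distributed_tensor(param_):
        # regular d-tensor: gather shards back into the global tensor
        return to_global(param_)
    elif is_customized_distributed_tensor(param_):
        # tensor sharded with a user-provided shard/gather function
        return to_global_for_customized_distributed_tensor(param_)
    else:
        # plain local tensor: nothing to gather
        return param_
```

Factoring the branch into a shared helper keeps the gathering logic in one place, so `ParallelModule._save_to_state_dict` and the new hybrid checkpoint IO handle distributed parameters consistently.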