[shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506)
* add APIs
* implement save_sharded_model
* add test for hybrid checkpointio
* implement naive loading for sharded model
* implement efficient sharded model loading
* open a new file for hybrid checkpoint_io
* small fix
* fix circular importing
* fix docstring
* arrange arguments and apis
* small fix
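For context, the behavior this commit enables would be exercised roughly as below: boost a model with HybridParallelPlugin, then save and reload it as a sharded checkpoint through the Booster checkpoint IO. This is a minimal sketch assuming the Booster / HybridParallelPlugin interfaces of this period of ColossalAI; the plugin sizes, the checkpoint directory name, and the size_per_shard value are illustrative assumptions, not taken from this commit.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# run under a distributed launcher, e.g. torchrun / colossalai.launch
colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)   # illustrative parallel sizes (assumption)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

# save as a sharded checkpoint: an index file plus several size-bounded shard files
booster.save_model(model, "hybrid_ckpt", shard=True, size_per_shard=1024)

# load the sharded checkpoint back into the boosted model
booster.load_model(model, "hybrid_ckpt")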
@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 
-from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.checkpoint_io.utils import StateDictSharder
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
@@ -657,7 +657,7 @@ class ZeroDDP(ColoDDP):
         Yields:
             Iterator[OrderedDict]: A generator of state dict shard
         """
-        sharder = _StateDictSharder(max_shard_size)
+        sharder = StateDictSharder(max_shard_size)
 
         # get the mapping between copies and fp16 parameters
         fp16_to_fp32 = dict()
@@ -705,30 +705,6 @@ class ZeroDDP(ColoDDP):
         yield sharder.current_block, sharder.current_block_size
 
 
-class _StateDictSharder:
-
-    def __init__(self, max_shard_size: int) -> None:
-        self.max_shard_size = max_shard_size
-        self.current_block = OrderedDict()
-        self.current_block_size = 0
-
-    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
-        tensor_size = calculate_tensor_size(tensor)
-        ret_block = None
-        ret_block_size = 0
-
-        # before we return the current block and create a new block,
-        # we need to ensure that the current block is not empty
-        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
-            ret_block = self.current_block
-            ret_block_size = self.current_block_size
-            self.current_block = OrderedDict()
-            self.current_block_size = 0
-        self.current_block[name] = tensor
-        self.current_block_size += tensor_size
-        return ret_block, ret_block_size
-
-
 class GeminiDDP(ZeroDDP):
 
     def __init__(self,
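For reference, the sharding logic this diff moves out of the Gemini module (the removed _StateDictSharder, now the shared StateDictSharder in colossalai.checkpoint_io.utils) works roughly as below: tensors are appended to the current block until adding one more would exceed max_shard_size, at which point the full block is emitted and a new one started. This is a self-contained sketch of the technique, not the library's exact code; the local calculate_tensor_size stand-in and the shard_state_dict driver are illustrative assumptions.

from collections import OrderedDict
from typing import Iterator, Optional, Tuple

import torch


def calculate_tensor_size(tensor: torch.Tensor) -> int:
    # stand-in for the helper from colossalai.checkpoint_io.utils: size of the tensor in bytes
    return tensor.numel() * tensor.element_size()


class StateDictSharder:
    """Accumulate (name, tensor) pairs into blocks of at most max_shard_size bytes."""

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.current_block = OrderedDict()
        self.current_block_size = 0

    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
        tensor_size = calculate_tensor_size(tensor)
        ret_block, ret_block_size = None, 0
        # close the current block only if it is non-empty and the new tensor would overflow it
        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
            ret_block, ret_block_size = self.current_block, self.current_block_size
            self.current_block, self.current_block_size = OrderedDict(), 0
        self.current_block[name] = tensor
        self.current_block_size += tensor_size
        return ret_block, ret_block_size


def shard_state_dict(state_dict: OrderedDict, max_shard_size: int) -> Iterator[Tuple[OrderedDict, int]]:
    """Yield (shard, shard_size) pairs, each shard at most max_shard_size bytes."""
    sharder = StateDictSharder(max_shard_size)
    for name, tensor in state_dict.items():
        block, block_size = sharder.append(name, tensor)
        if block is not None:
            yield block, block_size
    # flush the last, possibly partially filled block
    yield sharder.current_block, sharder.current_block_size


if __name__ == "__main__":
    model = torch.nn.Linear(1024, 1024)
    for i, (shard, size) in enumerate(shard_state_dict(model.state_dict(), max_shard_size=2 * 1024 * 1024)):
        print(f"shard {i}: {list(shard.keys())}, {size} bytes")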