Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)
* implement sharded optimizer saving
* add more param info
* finish implementation of sharded optimizer saving
* fix bugs in optimizer sharded saving
* add pp+zero test
* param group loading
* greedy loading of optimizer
* fix bug when loading
* implement optimizer sharded saving
* add optimizer test & arrange checkpointIO utils
* fix gemini sharding state_dict
* add verbose option
* add loading of master params
* fix typehint
* fix master/working mapping in fp16 amp
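As background for the feature this commit lands (not part of the diff itself), here is a minimal sketch of how a sharded optimizer checkpoint is saved and reloaded through the Booster API with HybridParallelPlugin. The plugin arguments, the checkpoint path, and the shard/size_per_shard keywords reflect the Booster checkpoint IO interface as I understand it for this era of the codebase and should be treated as assumptions:

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Launched via torchrun; the config argument was still required at this point (assumption).
colossalai.launch_from_torch(config={})

# Assumed plugin arguments: a ZeRO-1 + fp16 setup keeps the sketch runnable on few GPUs.
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, zero_stage=1, precision="fp16")
booster = Booster(plugin=plugin)

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... run a few training steps so the optimizer actually has state to save ...

# Save optimizer states as multiple shard files (capped at roughly size_per_shard MB each)
# plus an index that maps parameters to shards.
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, size_per_shard=1024)

# Reload; per the commit message, loading walks the shards greedily so each rank
# only materializes the states it owns.
booster.load_optimizer(optimizer, "ckpt/optimizer")

The diff below is the Gemini/ZeRO side of that change: the inline shard-packing loop in ZeroOptimizer.state_shard is replaced by the shared StateDictSharder helper.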
@@ -10,7 +10,7 @@ from torch.nn import Parameter
 from torch.optim import Optimizer
 
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
-from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.checkpoint_io.utils import StateDictSharder
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.tensor.d_tensor import is_distributed_tensor
@@ -691,49 +691,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
             Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
         """
 
-        current_block = {}
-        current_block_size = 0
-
+        sharder = StateDictSharder(max_shard_size)
         for param_id in self.id_to_real_params.keys():
 
             dist.barrier()
             state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
 
-            ret_block = None
-            ret_block_size = 0
+            block, block_size = sharder.append_optim_state(param_id, state)
+            if block is not None:
+                yield block, block_size
 
-            # A state might contain more than one tensors.
-            # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
-            state_size = 0
-            isDTensor = False
-            for state_tensor in state.values():
-
-                # When state_tensor is not of Tensor class,
-                # e.g., a SGD optimizer with momentum set to 0 can have None as state
-                # The calculation of tensor size should be skipped to avoid error.
-                if not isinstance(state_tensor, torch.Tensor):
-                    continue
-
-                # If the states are stored as DTensors, mark isDTensor as true.
-                if is_distributed_tensor(state_tensor):
-                    isDTensor = True
-                state_size += calculate_tensor_size(state_tensor)
-
-            if not isDTensor:
-
-                if current_block_size + state_size > max_shard_size and current_block_size > 0:
-                    ret_block = current_block
-                    ret_block_size = current_block_size
-                    current_block = {}
-                    current_block_size = 0
-
-                current_block[param_id] = state
-                current_block_size += state_size
-
-            if ret_block != None:
-                yield ret_block, ret_block_size
-
-        yield current_block, current_block_size
+        yield sharder.current_block, sharder.current_block_size
 
 
 class GeminiAdamOptimizer(ZeroOptimizer):
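The hunk above delegates the greedy block packing that used to live inline in state_shard to StateDictSharder. Below is a simplified, self-contained sketch of such an accumulator, reconstructed from the removed code; the class name is hypothetical, and the real helper in colossalai.checkpoint_io.utils also handles model state dicts:

from collections import OrderedDict
from typing import Optional, Tuple

import torch

from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.tensor.d_tensor import is_distributed_tensor


class SimpleOptimStateSharder:
    """Greedily packs per-parameter optimizer states into blocks of at most max_shard_size MB."""

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.current_block = OrderedDict()
        self.current_block_size = 0

    def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
        # A state may hold several tensors, e.g. Adam's 'step', 'exp_avg', 'exp_avg_sq'.
        state_size = 0
        is_dtensor = False
        for state_tensor in state.values():
            # Non-tensor entries (e.g. SGD with momentum 0 stores None) contribute no size.
            if not isinstance(state_tensor, torch.Tensor):
                continue
            if is_distributed_tensor(state_tensor):
                is_dtensor = True
            state_size += calculate_tensor_size(state_tensor)

        ret_block, ret_block_size = None, 0
        # States stored as distributed tensors are left out of the block,
        # mirroring the `if not isDTensor` branch of the removed code.
        if not is_dtensor:
            # If this state would overflow the current block, hand the full block
            # back to the caller (to be yielded) and start a fresh one.
            if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0:
                ret_block, ret_block_size = self.current_block, self.current_block_size
                self.current_block = OrderedDict()
                self.current_block_size = 0
            self.current_block[param_id] = state
            self.current_block_size += state_size
        return ret_block, ret_block_size

With an accumulator like this, the generator body reduces to the new three-line loop in the hunk: append each parameter's state, yield whatever full block comes back, and finally yield current_block / current_block_size for the leftover states.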