[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)
* sharded optimizer checkpoint for gemini plugin
* modify test to reduce testing time
* update doc
* fix bug when keep_gathered is true under GeminiPlugin
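For context, a hedged usage sketch of what this commit enables (not part of the diff): saving and reloading the Gemini optimizer state as sharded files through the Booster checkpoint API. The Booster method names, the GeminiPlugin arguments, and the path handling are assumptions to be checked against the installed ColossalAI version; HybridAdam is used because ZeroOptimizer expects one of the ColossalAI Adam variants.

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

plugin = GeminiPlugin()                     # Gemini (ZeRO-DDP) plugin
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)   # optimizer becomes a ZeroOptimizer

# shard=True stores the optimizer state as multiple files of at most
# size_per_shard MB each, backed by ZeroOptimizer.state_shard() below.
booster.save_optimizer(optimizer, 'optim_ckpt', shard=True, size_per_shard=1024)
booster.load_optimizer(optimizer, 'optim_ckpt')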
@@ -3,7 +3,7 @@ import copy
 import gc
 import math
 import warnings
-from typing import Any, Dict, Set, Tuple
+from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
 
 import torch
 import torch.distributed as dist
@@ -11,8 +11,10 @@ from torch.nn import Parameter
 from torch.optim import Optimizer
 
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
+from colossalai.checkpoint_io.utils import calculate_tensor_size
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
+from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.utils import disposable, get_current_device, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager
@@ -360,10 +362,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
 
         begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
         chunk_offset = begin_in_chunk
-        shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
+        if chunk.keep_gathered:
+            shard_offset = 0
+        else:
+            shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
         shard_size = end_in_chunk - begin_in_chunk
         assert chunk_offset >= 0 and shard_offset >= 0
 
         return chunk_offset, shard_offset, shard_size
 
     def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
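The hunk above is the keep_gathered fix from the commit message: for a chunk that stays gathered on every rank, the parameter's optimizer-state slice starts at the beginning of its own (un-sharded) storage, so the shard offset is simply 0. A small numeric illustration follows (the numbers are invented, not taken from the diff):

# Invented example values: a parameter spanning elements [100, 160) of its chunk,
# whose local shard starts at element 80, with the parameter's offset at 100.
begin_in_chunk, end_in_chunk = 100, 160
shard_begin, param_offset = 80, 100

chunk_offset = begin_in_chunk                                            # 100
shard_size = end_in_chunk - begin_in_chunk                               # 60

# Scattered chunk: index the slice relative to the locally held shard.
shard_offset_scattered = begin_in_chunk + shard_begin - param_offset     # 80

# keep_gathered chunk: the full state is held locally, so the slice starts at 0.
shard_offset_gathered = 0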
@@ -427,7 +431,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                                                                     dtype=torch.float32,
                                                                     requires_grad=False).cpu()
                     else:
-                        collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu()
+                        state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
+                        collected_states[state_name] = torch.reshape(state_tensor, param.shape)
             return collected_states
 
         # Check whether the param with given id is managed by current process.
@@ -536,6 +541,31 @@ class ZeroOptimizer(ColossalaiOptimizer):
                 target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
                 next_state_offset += shard_size
 
+    def get_param_groups_for_saving(self) -> list:
+        '''
+        Return the param_groups in Pytorch format when saving to checkpoint.
+        '''
+
+        param_groups = copy.deepcopy(self.param_groups_backup)
+
+        # To be compatible with pytorch checkpointing,
+        # store extra hyperparameters used by pytorch Adam optimizer.
+        torch_special_hyperparameters = {
+            'amsgrad': False,
+            'maximize': False,
+            'foreach': None,
+            'capturable': False,
+            'differentiable': False,
+            'fused': False
+        }
+
+        for group in param_groups:
+            for k, v in torch_special_hyperparameters.items():
+                if k not in group:
+                    group[k] = v
+
+        return param_groups
+
     def state_dict(self, only_rank_0: bool = True) -> dict:
         """
         Args:
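The new get_param_groups_for_saving pads each group with the hyperparameters a recent torch.optim.Adam writes into its state_dict, so the saved param_groups can later be loaded by a plain PyTorch optimizer. A minimal sketch of the effect on one hand-written group (the group values below are illustrative):

# Illustrative group; 'params' holds parameter indices as in a torch state_dict.
group = {'lr': 1e-3, 'betas': (0.9, 0.999), 'eps': 1e-8, 'weight_decay': 0.0, 'params': [0, 1]}

torch_special_hyperparameters = {
    'amsgrad': False,
    'maximize': False,
    'foreach': None,
    'capturable': False,
    'differentiable': False,
    'fused': False,
}

# Equivalent to the loop in the diff: fill in only the keys that are missing.
for k, v in torch_special_hyperparameters.items():
    group.setdefault(k, v)

# group now matches the layout torch.optim.Adam expects in load_state_dict().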
@@ -555,21 +585,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
            so it should be called only when memory resources are abundant.
         """
         state_dict = {}
-        state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)
-
-        torch_special_hyperparameters = {
-            'amsgrad': False,
-            'maximize': False,
-            'foreach': None,
-            'capturable': False,
-            'differentiable': False,
-            'fused': False
-        }
-
-        for group in state_dict['param_groups']:
-            for k, v in torch_special_hyperparameters.items():
-                if k not in group:
-                    group[k] = v
+        state_dict['param_groups'] = self.get_param_groups_for_saving()
 
         # Collect optimizer states.
         state_dict['state'] = dict()
@@ -634,8 +650,24 @@ class ZeroOptimizer(ColossalaiOptimizer):
                 del v    # clean loaded states
             self.optim.state[fake_param].update(updated_states)
 
+    def load_param_states(self, param_states: dict):
+        """Loads param states from a state_dict. The param_states can be complete or sharded.
+           During loading, filter out the part of states not considered by current process.
+
+        Args:
+            param_states (dict): A mapping from param_id to its states.
+        """
+        for param_id, states in param_states.items():
+            if param_id in self.id_to_fake_params:
+                self.load_single_param_states(param_id, states)
+
+    def optimizer_loading_epilogue(self):
+        # Epilogue when loading state_dict to pytorch optimizer.
+        self.optim._hook_for_profile()    # To support multiprocessing pickle/unpickle.
+        self.optim.defaults.setdefault('differentiable', False)
+
     def load_state_dict(self, state_dict: dict):
-        """Loads optimizer state from whole optimizer state_dict.
+        """Loads optimizer state from complete optimizer state_dict.
            During loading, filter out the part of states not considered by current process.
 
         Args:
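load_param_states and optimizer_loading_epilogue split the old monolithic load path so that sharded checkpoints can be fed in piece by piece: each shard goes through load_param_states (which silently skips params owned by other ranks), and the epilogue runs once at the end. A hedged sketch of a caller, assuming shard files are plain torch.save dicts named optim-*.bin; the real naming and index handling live in the Gemini checkpoint IO, not here:

import glob
import torch

def load_sharded_optimizer_states(zero_optimizer, checkpoint_dir: str) -> None:
    # Each shard file is assumed to map param_id -> {state_name: tensor}.
    for shard_file in sorted(glob.glob(f"{checkpoint_dir}/optim-*.bin")):
        shard = torch.load(shard_file, map_location='cpu')
        zero_optimizer.load_param_states(shard)      # ignores params not owned locally
    zero_optimizer.optimizer_loading_epilogue()      # restore torch optimizer invariants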
@@ -643,17 +675,71 @@ class ZeroOptimizer(ColossalaiOptimizer):
                from a call to :meth:`state_dict`.
         """
         assert 'param_groups' in state_dict
         assert 'state' in state_dict
         self.load_param_groups(state_dict['param_groups'])
-        state = state_dict['state']
-
-        for param_id, param_states in state.items():
-            if param_id in self.id_to_fake_params:
-                self.load_single_param_states(param_id, param_states)
-
-        # Epilogue for pytorch optimizer.
-        self.optim._hook_for_profile()    # To support multiprocessing pickle/unpickle.
-        self.optim.defaults.setdefault('differentiable', False)
+        self.load_param_states(state_dict['state'])
+        self.optimizer_loading_epilogue()
+
+    def state_shard(self,
+                    prefix: str = '',
+                    max_shard_size: int = 1024,
+                    only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]:
+        """Returns dictionaries containing shards of optimizer states one by one.
+           The max size of each dictionary shard is specified by ``max_shard_size``.
+
+        Args:
+            prefix (str, optional): the prefix for states. Defaults to ''.
+            max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
+            only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected
+                only on rank 0, defaults to True.
+
+        Yields:
+            Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
+        """
+
+        current_block = {}
+        current_block_size = 0
+
+        for param_id in self.id_to_real_params.keys():
+
+            dist.barrier()
+            state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
+
+            ret_block = None
+            ret_block_size = 0
+
+            # A state might contain more than one tensor.
+            # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
+            state_size = 0
+            isDTensor = False
+            for state_tensor in state.values():
+
+                # When state_tensor is not of Tensor class,
+                # e.g. an SGD optimizer with momentum set to 0 can have None as state.
+                # The calculation of tensor size should be skipped to avoid error.
+                if not isinstance(state_tensor, torch.Tensor):
+                    continue
+
+                # If the states are stored as DTensors, mark isDTensor as true.
+                if is_distributed_tensor(state_tensor):
+                    isDTensor = True
+                state_size += calculate_tensor_size(state_tensor)
+
+            if not isDTensor:
+
+                if current_block_size + state_size > max_shard_size and current_block_size > 0:
+                    ret_block = current_block
+                    ret_block_size = current_block_size
+                    current_block = {}
+                    current_block_size = 0
+
+                current_block[param_id] = state
+                current_block_size += state_size
+
+            if ret_block != None:
+                yield ret_block, ret_block_size
+
+        yield current_block, current_block_size
 
 
 class GeminiAdamOptimizer(ZeroOptimizer):
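state_shard is the producer side of the sharded checkpoint: every rank must iterate it together (collect_states synchronizes via dist.barrier), while only rank 0 receives the gathered states when only_rank_0=True. A hedged sketch of a consumer that writes one file per yielded block, assuming an ad-hoc naming scheme and index format rather than the plugin's real ones:

import os
import torch
import torch.distributed as dist

def save_sharded_optimizer_states(zero_optimizer, checkpoint_dir: str, max_shard_size: int = 1024) -> None:
    os.makedirs(checkpoint_dir, exist_ok=True)
    is_rank0 = dist.get_rank() == 0
    index = {}    # param_id -> shard file name (illustrative index format)

    for i, (block, block_size) in enumerate(
            zero_optimizer.state_shard(max_shard_size=max_shard_size, only_rank_0=True)):
        filename = f"optim-{i:05d}.bin"
        for param_id in block:
            index[param_id] = filename
        if is_rank0:    # states were collected onto rank 0, so only it writes
            torch.save(block, os.path.join(checkpoint_dir, filename))

    if is_rank0:
        torch.save(index, os.path.join(checkpoint_dir, "optim-index.bin"))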