[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)

* sharded optimizer checkpoint for gemini plugin

* modify test to reduce testing time

* update doc

* fix bug when keep_gathered is true under GeminiPlugin
Baizhou Zhang
2023-07-21 14:39:01 +08:00
committed by GitHub
parent fc5cef2c79
commit c6f6005990
12 changed files with 289 additions and 84 deletions
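For context, the sharded optimizer checkpoint added here is driven through the Booster API. A minimal usage sketch, assuming the Booster checkpoint interface of this release; the model, learning rate, and checkpoint path are placeholders:

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Run under torchrun so that launch_from_torch can read the distributed env vars.
colossalai.launch_from_torch(config={})
booster = Booster(plugin=GeminiPlugin())

model = torch.nn.Linear(16, 16).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training steps ...

# shard=True saves the optimizer states as multiple files,
# each at most size_per_shard MB.
booster.save_optimizer(optimizer, 'optim_ckpt', shard=True, size_per_shard=1024)
booster.load_optimizer(optimizer, 'optim_ckpt')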


@@ -3,7 +3,7 @@ import copy
import gc
import math
import warnings
-from typing import Any, Dict, Set, Tuple
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
import torch
import torch.distributed as dist
@@ -11,8 +11,10 @@ from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
@@ -360,10 +362,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
chunk_offset = begin_in_chunk
-        shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
if chunk.keep_gathered:
shard_offset = 0
else:
shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
shard_size = end_in_chunk - begin_in_chunk
assert chunk_offset >= 0 and shard_offset >= 0
return chunk_offset, shard_offset, shard_size
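To make the offset arithmetic concrete, here is a trace of the fixed branch with made-up numbers (all values below are hypothetical):

chunk_shard_begin = 1024                   # chunk.shard_begin
param_info_offset = 768                    # param_info.offset
begin_in_chunk, end_in_chunk = 800, 928    # self.param_to_range[fake_param]

chunk_offset = begin_in_chunk                                           # 800
# For a keep_gathered chunk, the fix above sets shard_offset = 0; otherwise:
shard_offset = begin_in_chunk + chunk_shard_begin - param_info_offset   # 1056
shard_size = end_in_chunk - begin_in_chunk                              # 128
assert chunk_offset >= 0 and shard_offset >= 0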
def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
@@ -427,7 +431,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
dtype=torch.float32,
requires_grad=False).cpu()
else:
-                    collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu()
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
return collected_states
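The added reshape matters because chunk-managed states are stored as flat 1-D slices, while a PyTorch checkpoint keeps each state tensor in the shape of its parameter. A toy illustration (shapes are made up):

import torch

param = torch.empty(4, 8)                           # the real parameter
flat_state = torch.arange(32, dtype=torch.float16)  # state as stored flat in a chunk
state_tensor = flat_state.detach().clone().to(torch.float32).cpu()
collected = torch.reshape(state_tensor, param.shape)
assert collected.shape == param.shape               # (4, 8), as torch.save expects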
# Check whether the param with the given id is managed by the current process.
@@ -536,6 +541,31 @@ class ZeroOptimizer(ColossalaiOptimizer):
target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
next_state_offset += shard_size
def get_param_groups_for_saving(self) -> list:
'''
Return the param_groups in PyTorch format when saving to checkpoint.
'''
param_groups = copy.deepcopy(self.param_groups_backup)
# To be compatible with pytorch checkpointing,
# store extra hyperparameters used by pytorch Adam optimizer.
torch_special_hyperparameters = {
'amsgrad': False,
'maximize': False,
'foreach': None,
'capturable': False,
'differentiable': False,
'fused': False
}
for group in param_groups:
for k, v in torch_special_hyperparameters.items():
if k not in group:
group[k] = v
return param_groups
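These defaults mirror the keys a native torch.optim.Adam writes into its param_groups, so a checkpoint saved here can be loaded by a plain PyTorch optimizer without missing-key surprises. On a recent PyTorch (2.x; older releases may lack some of these keys):

import torch

opt = torch.optim.Adam(torch.nn.Linear(4, 4).parameters())
print(sorted(opt.state_dict()['param_groups'][0]))
# ['amsgrad', 'betas', 'capturable', 'differentiable', 'eps', 'foreach',
#  'fused', 'lr', 'maximize', 'params', 'weight_decay']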
def state_dict(self, only_rank_0: bool = True) -> dict:
"""
Args:
@@ -555,21 +585,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
so it should be called only when memory resources are abundant.
"""
state_dict = {}
-        state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)
-        torch_special_hyperparameters = {
-            'amsgrad': False,
-            'maximize': False,
-            'foreach': None,
-            'capturable': False,
-            'differentiable': False,
-            'fused': False
-        }
-        for group in state_dict['param_groups']:
-            for k, v in torch_special_hyperparameters.items():
-                if k not in group:
-                    group[k] = v
state_dict['param_groups'] = self.get_param_groups_for_saving()
# Collect optimizer states.
state_dict['state'] = dict()
@@ -634,8 +650,24 @@ class ZeroOptimizer(ColossalaiOptimizer):
del v # clean loaded states
self.optim.state[fake_param].update(updated_states)
def load_param_states(self, param_states: dict):
"""Loads param states from a state_dict. The param_states can be complete or sharded.
During loading, filter out the part of states not considered by current process.
Args:
param_states (dict): A mapping from param_id to its states.
"""
for param_id, states in param_states.items():
if param_id in self.id_to_fake_params:
self.load_single_param_states(param_id, states)
def optimizer_loading_epilogue(self):
# Epilogue when loading state_dict to pytorch optimizer.
self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
self.optim.defaults.setdefault('differentiable', False)
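Splitting loading into load_param_groups / load_param_states / optimizer_loading_epilogue lets a sharded checkpoint be consumed shard by shard. A hedged sketch of how the pieces could compose; the optimizer instance, shard file names, and the torch.load I/O are illustrative, not the plugin's actual checkpoint layout:

import torch

# optimizer: a ZeroOptimizer; saved_param_groups: groups read from checkpoint metadata.
optimizer.load_param_groups(saved_param_groups)
for shard_file in ['optim_shard_0.bin', 'optim_shard_1.bin']:   # hypothetical names
    shard = torch.load(shard_file)      # mapping: param_id -> states
    optimizer.load_param_states(shard)  # non-local param_ids are filtered out
optimizer.optimizer_loading_epilogue()  # restore pytorch optimizer invariants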
def load_state_dict(self, state_dict: dict):
"""Loads optimizer state from whole optimizer state_dict.
"""Loads optimizer state from complete optimizer state_dict.
During loading, filter out the part of the states not considered by the current process.
Args:
@@ -643,17 +675,71 @@ class ZeroOptimizer(ColossalaiOptimizer):
from a call to :meth:`state_dict`.
"""
assert 'param_groups' in state_dict
assert 'state' in state_dict
self.load_param_groups(state_dict['param_groups'])
-        state = state_dict['state']
-        for param_id, param_states in state.items():
-            if param_id in self.id_to_fake_params:
-                self.load_single_param_states(param_id, param_states)
-        # Epilogue for pytorch optimizer.
-        self.optim._hook_for_profile()  # To support multiprocessing pickle/unpickle.
-        self.optim.defaults.setdefault('differentiable', False)
        self.load_param_states(state_dict['state'])
        self.optimizer_loading_epilogue()

    def state_shard(self,
                    prefix: str = '',
                    max_shard_size: int = 1024,
                    only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]:
        """Returns dictionaries containing shards of optimizer states one by one.
        The max size of each dictionary shard is specified by ``max_shard_size``.

        Args:
            prefix (str, optional): the prefix for states. Defaults to ''.
            max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
            only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected
                only on rank 0, defaults to True.

        Yields:
            Iterator[OrderedDict]: A generator of state dict shards of optimizer states.
        """
current_block = {}
current_block_size = 0
for param_id in self.id_to_real_params.keys():
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
ret_block = None
ret_block_size = 0
            # A state might contain more than one tensor,
            # e.g. each Adam state includes 'step', 'exp_avg' and 'exp_avg_sq'.
state_size = 0
isDTensor = False
for state_tensor in state.values():
                # When state_tensor is not of Tensor class,
                # e.g. an SGD optimizer with momentum set to 0 can have None as a state,
                # the tensor size calculation should be skipped to avoid errors.
if not isinstance(state_tensor, torch.Tensor):
continue
# If the states are stored as DTensors, mark isDTensor as true.
if is_distributed_tensor(state_tensor):
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
if not isDTensor:
if current_block_size + state_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
current_block[param_id] = state
current_block_size += state_size
            if ret_block is not None:
yield ret_block, ret_block_size
yield current_block, current_block_size
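A minimal sketch of consuming state_shard, assuming a boosted ZeroOptimizer instance and hypothetical file names; note that every rank must drive the generator, since collect_states synchronizes with dist.barrier():

import torch
import torch.distributed as dist

for idx, (shard, shard_size) in enumerate(optimizer.state_shard(max_shard_size=1024)):
    if dist.get_rank() == 0:   # with only_rank_0=True, the collected states live on rank 0
        torch.save(shard, f'optim_states_shard_{idx}.bin')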
class GeminiAdamOptimizer(ZeroOptimizer):