[booster] gemini plugin support shard checkpoint (#3610)

* gemini plugin add shard checkpoint save/load

* gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>

Author: jiangmingyan
Date: 2023-05-05 14:37:21 +08:00
Committed by: GitHub
Parent: 0f785cb1f3
Commit: 307894f74d

9 changed files with 268 additions and 96 deletions
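
At the user level, this change is about saving a Gemini-wrapped model as a sharded checkpoint through the booster API. A rough usage sketch follows; the launch call and the `save_model` keyword names (`shard`, `size_per_shard`) are taken from the later public Booster docs and are assumptions for this exact revision, not code from this commit.

```python
import colossalai
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumes the script is launched with torchrun so colossalai can read the
# distributed environment variables.
colossalai.launch_from_torch(config={})

model = nn.Linear(32, 32)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

booster = Booster(plugin=GeminiPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Sharded save: the state dict is split into blocks bounded by size_per_shard.
# Keyword names are assumptions, not verified against this revision.
booster.save_model(model, "checkpoint_dir", shard=True, size_per_shard=1024)
```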

@@ -2,7 +2,7 @@ import itertools
 from collections import OrderedDict
 from contextlib import nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
 
 import torch
 import torch.distributed as dist
@@ -96,8 +96,35 @@ class ZeroDDP(ColoDDP):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
         super().__init__(module, process_group=ColoProcessGroup())
+        self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
+    def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
+        r"""
+        Args:
+            memo: a memo to store the set of modules already added to the result
+            prefix: a prefix that will be added to the name of the module
+            remove_duplicate: whether to remove the duplicated module instances in the result
+                or not
+        """
+        if memo is None:
+            memo = set()
+        self_non_persistent_set = set()
+        if module not in memo:
+            if remove_duplicate:
+                memo.add(module)
+            self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
+            for name, sub_module in module._modules.items():
+                if sub_module is None:
+                    continue
+                submodule_prefix = prefix + ('.' if prefix else '') + name
+                child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
+                self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+        return self_non_persistent_set
+
     def _post_forward(self):
         """This function is only triggered for inference.
         """
@@ -604,7 +631,7 @@ class ZeroDDP(ColoDDP):
                          keep_vars: bool = False,
                          max_shard_size: int = 1024,
                          only_rank_0: bool = True,
-                         dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]:
+                         dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
 
         Both parameters and persistent buffers (e.g. running averages) are included.
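
With the new return type, every yielded item is a `(shard, shard_size)` pair instead of a bare `OrderedDict`. A hedged usage sketch of the sharded save loop; the helper name, file naming, and directory handling are illustrative and not this repo's checkpoint IO.

```python
import os
import torch


def save_sharded(model, out_dir: str) -> None:
    # `model` is assumed to be a ZeroDDP/GeminiDDP instance exposing state_dict_shard.
    os.makedirs(out_dir, exist_ok=True)
    for idx, (shard, shard_size) in enumerate(model.state_dict_shard(max_shard_size=1024)):
        # `shard` is an OrderedDict of tensors; `shard_size` is what the sharder
        # accumulated against the max_shard_size budget for this block.
        torch.save(shard, os.path.join(out_dir, f"pytorch_model-{idx:05d}.bin"))
```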
@@ -644,9 +671,9 @@ class ZeroDDP(ColoDDP):
                         gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
                     gathered_param = gathered_param_buffer.pop(fp32_param)
 
-                block = sharder.append(prefix + name, gathered_param)
+                block, block_size = sharder.append(prefix + name, gathered_param)
                 if block is not None:
-                    yield block
+                    yield block, block_size
 
         del fp16_to_fp32
         del gathered_param_buffer
@@ -655,19 +682,19 @@ class ZeroDDP(ColoDDP):
         for name, buf in self.named_buffers():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
-                block = sharder.append(prefix + name, buffer)
+                block, block_size = sharder.append(prefix + name, buffer)
                 if block is not None:
-                    yield block
+                    yield block, block_size
         # save extra states
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
         if getattr(self.__class__, "get_extra_state",
                    torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
             extra_state = self.get_extra_state()
-            block = sharder.append(extra_state_key, extra_state)
+            block, block_size = sharder.append(extra_state_key, extra_state)
             if block is not None:
-                yield block
+                yield block, block_size
 
-        yield sharder.current_block
+        yield sharder.current_block, sharder.current_block_size
 
 
 class _StateDictSharder:
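
Yielding `block_size` alongside each block lets the checkpoint writer keep bookkeeping such as an index mapping every parameter name to the shard file that holds it, plus a running total size. A sketch of that bookkeeping; the index layout and helper name are assumptions modeled on the common `weight_map` index format, not code from this commit.

```python
import json


def write_index(shard_records, index_path: str) -> None:
    # shard_records: iterable of (filename, shard_dict, shard_size) tuples
    # collected while consuming state_dict_shard(...).
    weight_map, total_size = {}, 0
    for filename, shard, shard_size in shard_records:
        for key in shard:
            weight_map[key] = filename
        total_size += shard_size
    with open(index_path, "w") as f:
        json.dump({"metadata": {"total_size": total_size}, "weight_map": weight_map}, f, indent=2)
```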
@@ -677,16 +704,18 @@
         self.current_block = OrderedDict()
         self.current_block_size = 0
 
-    def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
+    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
         tensor_size = calculate_tensor_size(tensor)
         ret_block = None
+        ret_block_size = 0
         if self.current_block_size + tensor_size > self.max_shard_size:
             ret_block = self.current_block
+            ret_block_size = self.current_block_size
             self.current_block = OrderedDict()
             self.current_block_size = 0
         self.current_block[name] = tensor
         self.current_block_size += tensor_size
-        return ret_block
+        return ret_block, ret_block_size
 
 
 class GeminiDDP(ZeroDDP):
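
The updated `_StateDictSharder.append` contract: it returns `(completed_block, completed_block_size)` only when the incoming tensor would overflow the current shard, otherwise `(None, 0)`, and the caller must flush `current_block` at the end (hence the final `yield sharder.current_block, sharder.current_block_size`). A self-contained toy replica of that behavior, with a plain byte count standing in for `calculate_tensor_size`.

```python
from collections import OrderedDict
from typing import Optional, Tuple

import torch


def tensor_size(t: torch.Tensor) -> int:
    # Stand-in for calculate_tensor_size; here simply the bytes occupied.
    return t.numel() * t.element_size()


class ToySharder:
    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.current_block: OrderedDict = OrderedDict()
        self.current_block_size = 0

    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
        size = tensor_size(tensor)
        ret_block, ret_block_size = None, 0
        if self.current_block_size + size > self.max_shard_size:
            # Current shard is full: hand it back together with its size.
            ret_block, ret_block_size = self.current_block, self.current_block_size
            self.current_block, self.current_block_size = OrderedDict(), 0
        self.current_block[name] = tensor
        self.current_block_size += size
        return ret_block, ret_block_size


sharder = ToySharder(max_shard_size=64)
shards = []
for i in range(5):
    block, block_size = sharder.append(f"w{i}", torch.zeros(8))  # 32 bytes each
    if block is not None:
        shards.append((block, block_size))
shards.append((sharder.current_block, sharder.current_block_size))  # flush the tail
print([(list(b), s) for b, s in shards])
# [(['w0', 'w1'], 64), (['w2', 'w3'], 64), (['w4'], 32)]
```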