Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[booster] gemini plugin support shard checkpoint (#3610)
* gemini plugin add shard checkpoint save/load
* gemini plugin support shard checkpoint
* [API Refactoring] gemini plugin support shard checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
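For orientation, here is a rough sketch of the user-facing save path this change is meant to support. Only the ZeroDDP/_StateDictSharder internals are touched in the diff below; the booster-side call shown here, in particular the `shard` and `size_per_shard` keywords of `Booster.save_model`, is an assumption for illustration and is not confirmed by this commit.

# Hypothetical usage sketch (not part of this diff). Run under torchrun so that
# distributed initialization succeeds.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})

plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

# Save the model state dict in shards of at most ~1 GiB each
# (keyword names are assumptions, see lead-in above).
booster.save_model(model, "checkpoint_dir/", shard=True, size_per_shard=1024)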
@@ -2,7 +2,7 @@ import itertools
 from collections import OrderedDict
 from contextlib import nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union, Tuple, Set

 import torch
 import torch.distributed as dist
@@ -96,8 +96,35 @@ class ZeroDDP(ColoDDP):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
         super().__init__(module, process_group=ColoProcessGroup())
+        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()

+    def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
+
+        r"""
+        Args:
+            memo: a memo to store the set of modules already added to the result
+            prefix: a prefix that will be added to the name of the module
+            remove_duplicate: whether to remove the duplicated module instances in the result
+                or not
+        """
+
+        if memo is None:
+            memo = set()
+        self_non_persistent_set = set()
+        if module not in memo:
+            if remove_duplicate:
+                memo.add(module)
+            self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
+        for name, sub_module in module._modules.items():
+            if sub_module is None:
+                continue
+            submodule_prefix = prefix + ('.' if prefix else '') + name
+            child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
+            self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+        return self_non_persistent_set
+
+
     def _post_forward(self):
         """This function is only triggered for inference.
         """
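The new `_get_non_persistent_buffers_set` walks the module tree and records the fully qualified names of buffers registered with `persistent=False`, so the sharded state dict can skip them exactly as `torch.nn.Module.state_dict` does. A minimal standalone sketch of the same traversal (independent of ZeroDDP, for illustration only):

from typing import Optional, Set

import torch
import torch.nn as nn


def non_persistent_buffer_names(module: nn.Module,
                                memo: Optional[Set[nn.Module]] = None,
                                prefix: str = '') -> Set[str]:
    """Collect dotted names of buffers registered with persistent=False."""
    if memo is None:
        memo = set()
    names = set()
    if module not in memo:
        memo.add(module)
        # nn.Module tracks non-persistent buffer names of *this* module only.
        names = {prefix + ('.' if prefix else '') + key for key in module._non_persistent_buffers_set}
    for child_name, child in module._modules.items():
        if child is None:
            continue
        child_prefix = prefix + ('.' if prefix else '') + child_name
        names |= non_persistent_buffer_names(child, memo, child_prefix)
    return names


class Block(nn.Module):

    def __init__(self):
        super().__init__()
        self.register_buffer('running_mean', torch.zeros(4))               # persistent, kept in state_dict
        self.register_buffer('scratch', torch.zeros(4), persistent=False)  # non-persistent, skipped


model = nn.Sequential(Block(), Block())
print(non_persistent_buffer_names(model))  # {'0.scratch', '1.scratch'}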
@@ -604,7 +631,7 @@ class ZeroDDP(ColoDDP):
                          keep_vars: bool = False,
                          max_shard_size: int = 1024,
                          only_rank_0: bool = True,
-                         dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]:
+                         dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.

         Both parameters and persistent buffers (e.g. running averages) are included.
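With this signature change, each yielded item is a `(state_dict_shard, shard_size)` pair instead of a bare OrderedDict, so the caller can build a sharded-checkpoint index without re-measuring tensors. A hedged sketch of such a consumer; the file naming and index layout below follow the common Hugging Face convention for illustration and are not taken from this diff:

import json
import os
from typing import Iterator, OrderedDict, Tuple

import torch


def save_sharded(shard_iter: Iterator[Tuple[OrderedDict, int]], checkpoint_dir: str) -> None:
    """Write each yielded shard to its own file and record an index of weights and shard sizes."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    weight_map, total_size = {}, 0
    for idx, (shard, shard_size) in enumerate(shard_iter):
        filename = f"pytorch_model-{idx:05d}.bin"
        torch.save(shard, os.path.join(checkpoint_dir, filename))
        for key in shard:
            weight_map[key] = filename
        total_size += shard_size
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    with open(os.path.join(checkpoint_dir, "pytorch_model.bin.index.json"), "w") as f:
        json.dump(index, f, indent=2)


# Roughly how the gemini plugin's checkpoint IO could drive it:
#   save_sharded(zero_ddp_model.state_dict_shard(max_shard_size=1024), "checkpoint_dir/")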
@@ -644,9 +671,9 @@ class ZeroDDP(ColoDDP):
                     gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
                 gathered_param = gathered_param_buffer.pop(fp32_param)

-            block = sharder.append(prefix + name, gathered_param)
+            block, block_size = sharder.append(prefix + name, gathered_param)
             if block is not None:
-                yield block
+                yield block, block_size

         del fp16_to_fp32
         del gathered_param_buffer
@@ -655,19 +682,19 @@ class ZeroDDP(ColoDDP):
         for name, buf in self.named_buffers():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
-                block = sharder.append(prefix + name, buffer)
+                block, block_size = sharder.append(prefix + name, buffer)
                 if block is not None:
-                    yield block
+                    yield block, block_size
         # save extra states
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
         if getattr(self.__class__, "get_extra_state",
                    torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
             extra_state = self.get_extra_state()
-            block = sharder.append(extra_state_key, extra_state)
+            block, block_size = sharder.append(extra_state_key, extra_state)
             if block is not None:
-                yield block
+                yield block, block_size

-        yield sharder.current_block
+        yield sharder.current_block, sharder.current_block_size


 class _StateDictSharder:
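The extra-state branch mirrors what `torch.nn.Module.state_dict` does: if a module overrides `get_extra_state`, its return value is stored under `<prefix>_extra_state` (PyTorch's `_EXTRA_STATE_KEY_SUFFIX`), and the sharder must emit that key too. A small PyTorch-only example of where that key comes from:

import torch
import torch.nn as nn


class Counter(nn.Module):
    """Module that persists a non-tensor value through get/set_extra_state."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.steps = 0

    def get_extra_state(self):
        return {"steps": self.steps}

    def set_extra_state(self, state):
        self.steps = state["steps"]


m = Counter()
m.steps = 7
sd = m.state_dict()
print(sorted(sd.keys()))
# ['_extra_state', 'linear.bias', 'linear.weight']  -> the extra-state key the sharder also yields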
@@ -677,16 +704,18 @@ class _StateDictSharder:
         self.current_block = OrderedDict()
         self.current_block_size = 0

-    def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
+    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
         tensor_size = calculate_tensor_size(tensor)
         ret_block = None
+        ret_block_size = 0
         if self.current_block_size + tensor_size > self.max_shard_size:
             ret_block = self.current_block
+            ret_block_size = self.current_block_size
             self.current_block = OrderedDict()
             self.current_block_size = 0
         self.current_block[name] = tensor
         self.current_block_size += tensor_size
-        return ret_block
+        return ret_block, ret_block_size


 class GeminiDDP(ZeroDDP):
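Returning `(ret_block, ret_block_size)` lets the caller report how large each flushed shard actually is. Below is a self-contained sketch of the same accumulate-and-flush pattern. The size helper assumes `calculate_tensor_size` measures tensors in MB (numel x element size / 1024^2), which matches the `max_shard_size: int = 1024` default above but is an assumption about its unit:

from collections import OrderedDict
from typing import Optional, Tuple

import torch


def tensor_size_mb(tensor: torch.Tensor) -> float:
    # Assumed to mirror calculate_tensor_size: tensor footprint in MB.
    return tensor.numel() * tensor.element_size() / 1024 / 1024


class StateDictSharderSketch:
    """Accumulate tensors into blocks; flush a block once it would exceed max_shard_size (MB)."""

    def __init__(self, max_shard_size: int = 1024) -> None:
        self.max_shard_size = max_shard_size
        self.current_block = OrderedDict()
        self.current_block_size = 0.0

    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], float]:
        tensor_size = tensor_size_mb(tensor)
        ret_block, ret_block_size = None, 0.0
        if self.current_block_size + tensor_size > self.max_shard_size:
            # Flush the filled block, report its size, and start a new one.
            ret_block, ret_block_size = self.current_block, self.current_block_size
            self.current_block, self.current_block_size = OrderedDict(), 0.0
        self.current_block[name] = tensor
        self.current_block_size += tensor_size
        return ret_block, ret_block_size


sharder = StateDictSharderSketch(max_shard_size=1)  # 1 MB shards for the demo
for i in range(5):
    block, size = sharder.append(f"layer{i}.weight", torch.zeros(256, 1024))  # 1 MB each in fp32
    if block is not None:
        print(f"flush shard of {size:.1f} MB with keys {list(block)}")
# The final, partially filled block must still be emitted, as state_dict_shard does above:
print(f"last shard of {sharder.current_block_size:.1f} MB with keys {list(sharder.current_block)}")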