Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
[gemini] support save state dict in shards (#3581)
* [gemini] support state dict shard
* [gemini] add test state dict shard
* [gemini] polish docstr
* [gemini] fix merge
* [gemini] polish code
@@ -1,12 +1,13 @@
 import itertools
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 
+from colossalai.checkpoint_io.utils import calculate_tensor_size
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
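The newly imported `calculate_tensor_size` is what the `_StateDictSharder` added below uses to weigh tensors against `max_shard_size`. A minimal sketch of the likely semantics, assuming size is measured as element count times bytes per element; the `tensor_size_mb` name here is a stand-in for illustration, not the library API:

import torch

def tensor_size_mb(tensor: torch.Tensor) -> float:
    # Stand-in for colossalai.checkpoint_io.utils.calculate_tensor_size:
    # megabytes = number of elements * bytes per element / 1024^2.
    return tensor.numel() * tensor.element_size() / 1024 / 1024

print(tensor_size_mb(torch.empty(1024, 1024, dtype=torch.float32)))  # 4.0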
@@ -228,6 +229,32 @@ class ZeroDDP(ColoDDP):
                 destination = hook_result
         return destination
 
+    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
+        """
+        Get gathered chunk content.
+
+        Args:
+            chunk (Chunk): a chunk
+            only_rank_0 (bool): whether to only save data on rank 0
+
+        Returns:
+            Dict: a dict whose keys are parameters and whose values are tensors holding the gathered payload
+        """
+        # save parameters
+        chunk_to_save_data = dict()
+        temp_chunk = get_temp_total_chunk_on_cuda(chunk)
+        for tensor, tensor_info in chunk.tensors_info.items():
+            record_tensor = torch.empty([0])
+            record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
+            if record_flag:
+                record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
+
+            assert tensor not in chunk_to_save_data
+            chunk_to_save_data[tensor] = record_tensor
+
+        del temp_chunk
+        return chunk_to_save_data
+
     def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
         """
         Get param content from chunks.
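The heart of `_get_chunk_to_save_data` is the slice-and-view step: a chunk is one flat buffer, and each tensor is recovered from it via its recorded offset, end, and shape. A minimal single-process sketch with made-up metadata; no Gemini chunk machinery is involved, and the names are illustrative:

import torch

# Toy stand-in for a gathered chunk: two flattened tensors packed end to end.
temp_chunk = torch.arange(12, dtype=torch.float32)
# name -> (offset, end, shape), mimicking chunk.tensors_info entries
tensors_info = {"weight": (0, 6, (2, 3)), "bias": (6, 12, (6,))}

chunk_to_save_data = {}
for name, (offset, end, shape) in tensors_info.items():
    # mirrors: temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
    chunk_to_save_data[name] = temp_chunk[offset:end].view(shape).cpu()

print(chunk_to_save_data["weight"].shape)    # torch.Size([2, 3])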
@@ -243,18 +270,7 @@ class ZeroDDP(ColoDDP):
         param_to_save_data = dict()
         chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
-            temp_chunk = get_temp_total_chunk_on_cuda(chunk)
-
-            for tensor, tensor_info in chunk.tensors_info.items():
-                record_tensor = torch.empty([0])
-                record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
-                if record_flag:
-                    record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
-
-                assert tensor not in param_to_save_data
-                param_to_save_data[tensor] = record_tensor
-
-            del temp_chunk
+            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
         return param_to_save_data
 
     def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
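This hunk is a pure refactor: the inlined gather loop is deleted, and `_get_param_to_save_data` now merges each chunk's result with `dict.update`, so the gather logic lives in one place and is reusable by `state_dict_shard` below. The merge pattern, shown with stand-in data:

# Stand-ins for the dicts returned by _get_chunk_to_save_data, one per chunk.
per_chunk_results = [
    {"encoder.weight": "tensor-0"},
    {"encoder.bias": "tensor-1", "head.weight": "tensor-2"},
]

param_to_save_data = {}
for chunk_result in per_chunk_results:
    # mirrors: param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
    param_to_save_data.update(chunk_result)

assert len(param_to_save_data) == 3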
@@ -554,6 +570,93 @@ class ZeroDDP(ColoDDP):
             p.__class__ = ColoParameter
             p.__init__(p, requires_grad=requires_grad)
 
+    def state_dict_shard(self,
+                         prefix: str = '',
+                         keep_vars: bool = False,
+                         max_shard_size: int = 1024,
+                         only_rank_0: bool = True) -> Iterator[OrderedDict]:
+        """Returns dictionaries containing a whole state of the module one by one. The max size of a dictionary shard is specified by ``max_shard_size``.
+
+        Both parameters and persistent buffers (e.g. running averages) are included.
+        Keys are corresponding parameter and buffer names.
+        Parameters and buffers set to ``None`` are not included.
+
+        Args:
+            prefix (str, optional): the prefix for parameters and buffers used in this
+                module. Defaults to ''.
+            keep_vars (bool, optional): whether to keep variables. Defaults to False.
+            max_shard_size (int, optional): max size of a state dict shard (in MB). Defaults to 1024.
+            only_rank_0 (bool, optional): only collect data on rank 0. Defaults to True.
+
+
+        Yields:
+            Iterator[OrderedDict]: A generator of state dict shards.
+        """
+        sharder = _StateDictSharder(max_shard_size)
+
+        # get the mapping between copies and fp16 parameters
+        fp16_to_fp32 = dict()
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+            fp16_to_fp32[p] = fp32_p
+
+        # key is fp32 param, and value is gathered param on CPU
+        gathered_param_buffer = dict()
+        for name, param in self.name2param.items():
+            if param is not None:
+                if is_ddp_ignored(param):
+                    # deal with ddp ignored parameters
+                    gathered_param = param if keep_vars else param.detach()
+                else:
+                    fp32_param = fp16_to_fp32[param]
+                    if fp32_param not in gathered_param_buffer:
+                        chunk = self.chunk_manager.get_chunk(fp32_param)
+                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+                    gathered_param = gathered_param_buffer.pop(fp32_param)
+
+                block = sharder.append(prefix + name, gathered_param)
+                if block is not None:
+                    yield block
+
+        del fp16_to_fp32
+        del gathered_param_buffer
+
+        # save all buffers
+        for name, buf in self.named_buffers():
+            if buf is not None and name not in self._non_persistent_buffers_set:
+                buffer = buf if keep_vars else buf.detach()
+                block = sharder.append(prefix + name, buffer)
+                if block is not None:
+                    yield block
+        # save extra states
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state",
+                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+            extra_state = self.get_extra_state()
+            block = sharder.append(extra_state_key, extra_state)
+            if block is not None:
+                yield block
+
+        yield sharder.current_block
+
+
+class _StateDictSharder:
+
+    def __init__(self, max_shard_size: int) -> None:
+        self.max_shard_size = max_shard_size
+        self.current_block = OrderedDict()
+        self.current_block_size = 0
+
+    def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
+        tensor_size = calculate_tensor_size(tensor)
+        ret_block = None
+        if self.current_block_size + tensor_size > self.max_shard_size:
+            ret_block = self.current_block
+            self.current_block = OrderedDict()
+            self.current_block_size = 0
+        self.current_block[name] = tensor
+        self.current_block_size += tensor_size
+        return ret_block
+
 
 class GeminiDDP(ZeroDDP):
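`_StateDictSharder.append` implements a simple greedy policy: tensors accumulate into the current block until adding one would push it past `max_shard_size`, at which point the full block is handed back and a fresh one is started. A block is therefore emitted only on the next append that overflows, which is why `state_dict_shard` must `yield sharder.current_block` at the end. A self-contained re-creation of that policy, with sizes supplied directly instead of via `calculate_tensor_size`:

from collections import OrderedDict
from typing import Optional

class GreedySharder:
    """Re-creation of _StateDictSharder's policy with explicit sizes."""

    def __init__(self, max_shard_size: float) -> None:
        self.max_shard_size = max_shard_size
        self.current_block = OrderedDict()
        self.current_block_size = 0.0

    def append(self, name: str, size_mb: float) -> Optional[OrderedDict]:
        ret_block = None
        if self.current_block_size + size_mb > self.max_shard_size:
            ret_block = self.current_block       # close out the full block
            self.current_block = OrderedDict()   # and start a new one
            self.current_block_size = 0.0
        self.current_block[name] = size_mb
        self.current_block_size += size_mb
        return ret_block

sharder = GreedySharder(max_shard_size=10)
for name, size in [("a", 6), ("b", 6), ("c", 3)]:
    block = sharder.append(name, size)
    if block is not None:
        print("shard:", list(block))             # shard: ['a'] when "b" overflows
print("last shard:", list(sharder.current_block))  # last shard: ['b', 'c']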
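Putting it together, a hedged usage sketch: `model` is assumed to be a module already wrapped in `GeminiDDP`/`ZeroDDP` (construction not shown), and the file-naming scheme is illustrative rather than ColossalAI's checkpoint format. Each yielded shard is an ordinary `OrderedDict` capped at roughly `max_shard_size` MB, so it can be saved with plain `torch.save`:

import torch
import torch.distributed as dist

# `model` is assumed to be a GeminiDDP-wrapped module (setup elided).
for idx, shard in enumerate(model.state_dict_shard(max_shard_size=1024)):
    # With only_rank_0=True (the default), only rank 0 holds real payloads;
    # other ranks see empty placeholder tensors.
    if dist.get_rank() == 0:
        torch.save(shard, f"model_shard_{idx:05d}.bin")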