Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -10,7 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group
 
-from colossalai.checkpoint_io.utils import StateDictSharder, calculate_tensor_size
+from colossalai.checkpoint_io.utils import StateDictSharder
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
@@ -27,10 +27,10 @@ from .utils import get_temp_total_chunk_on_cuda
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
 except ImportError:
-    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
 
 __all__ = [
-    'GeminiDDP',
+    "GeminiDDP",
 ]
 
 
@@ -54,27 +54,28 @@ class GeminiDDP(ModelWrapper):
     """
 
     def __init__(
-            self,
-            module: torch.nn.Module,
-            chunk_config_dict: Optional[dict] = None,
-            chunk_init_device: torch.device = torch.device('cpu'),
-            placement_policy: str = "static",
-            shard_param_frac: float = 1.0,    # only for static placement
-            offload_optim_frac: float = 0.0,    # only for static placement
-            offload_param_frac: float = 0.0,    # only for static placement
-            warmup_non_model_data_ratio: float = 0.8,    # only for auto placement
-            steady_cuda_cap_ratio: float = 0.9,    # only for auto placement
-            search_range_m: int = 32,    # chunk search options
-            hidden_dim: Optional[int] = None,    # chunk search options
-            min_chunk_size_m: float = 32,    # chunk search options
-            pin_memory: bool = False,
-            force_outputs_fp32: bool = False,
-            strict_ddp_mode: bool = False,
-            scatter_after_inference: bool = True,
-            mixed_precision: torch.dtype = torch.float16,
-            process_group: Optional[ProcessGroup] = None,
-            memstats: Optional[MemStats] = None,    # genimi memory stats
-            verbose: bool = False) -> None:
+        self,
+        module: torch.nn.Module,
+        chunk_config_dict: Optional[dict] = None,
+        chunk_init_device: torch.device = torch.device("cpu"),
+        placement_policy: str = "static",
+        shard_param_frac: float = 1.0,  # only for static placement
+        offload_optim_frac: float = 0.0,  # only for static placement
+        offload_param_frac: float = 0.0,  # only for static placement
+        warmup_non_model_data_ratio: float = 0.8,  # only for auto placement
+        steady_cuda_cap_ratio: float = 0.9,  # only for auto placement
+        search_range_m: int = 32,  # chunk search options
+        hidden_dim: Optional[int] = None,  # chunk search options
+        min_chunk_size_m: float = 32,  # chunk search options
+        pin_memory: bool = False,
+        force_outputs_fp32: bool = False,
+        strict_ddp_mode: bool = False,
+        scatter_after_inference: bool = True,
+        mixed_precision: torch.dtype = torch.float16,
+        process_group: Optional[ProcessGroup] = None,
+        memstats: Optional[MemStats] = None,  # genimi memory stats
+        verbose: bool = False,
+    ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         if chunk_config_dict is not None:
             self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
@@ -82,22 +83,26 @@ class GeminiDDP(ModelWrapper):
             # some ugly hotfix for the compatibility with Lightning
             if search_range_m is None:
                 search_range_m = 32
-            self.chunk_manager = init_chunk_manager(model=module,
-                                                    init_device=chunk_init_device,
-                                                    hidden_dim=hidden_dim,
-                                                    search_range_m=search_range_m,
-                                                    min_chunk_size_m=min_chunk_size_m,
-                                                    strict_ddp_flag=strict_ddp_mode,
-                                                    process_group=process_group,
-                                                    verbose=verbose)
-        self.gemini_manager = GeminiManager(placement_policy,
-                                            self.chunk_manager,
-                                            memstats,
-                                            shard_param_frac=shard_param_frac,
-                                            offload_optim_frac=offload_optim_frac,
-                                            offload_param_frac=offload_param_frac,
-                                            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
-                                            steady_cuda_cap_ratio=steady_cuda_cap_ratio)
+            self.chunk_manager = init_chunk_manager(
+                model=module,
+                init_device=chunk_init_device,
+                hidden_dim=hidden_dim,
+                search_range_m=search_range_m,
+                min_chunk_size_m=min_chunk_size_m,
+                strict_ddp_flag=strict_ddp_mode,
+                process_group=process_group,
+                verbose=verbose,
+            )
+        self.gemini_manager = GeminiManager(
+            placement_policy,
+            self.chunk_manager,
+            memstats,
+            shard_param_frac=shard_param_frac,
+            offload_optim_frac=offload_optim_frac,
+            offload_param_frac=offload_param_frac,
+            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
+            steady_cuda_cap_ratio=steady_cuda_cap_ratio,
+        )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
         self.fp32_params: List[torch.Tensor] = list()
@@ -126,13 +131,15 @@ class GeminiDDP(ModelWrapper):
             self.param2name[param] = name
         for m_name, m_var in module.named_modules():
             for p_name, p_var in m_var.named_parameters(recurse=False):
-                param_name = m_name + '.' + p_name if m_name else p_name
+                param_name = m_name + "." + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
 
-        self._init_chunks(param_order=param_order,
-                          strict_ddp_mode=strict_ddp_mode,
-                          cpu_offload=self.gemini_manager.policy_name != 'cuda',
-                          pin_memory=pin_memory)
+        self._init_chunks(
+            param_order=param_order,
+            strict_ddp_mode=strict_ddp_mode,
+            cpu_offload=self.gemini_manager.policy_name != "cuda",
+            pin_memory=pin_memory,
+        )
         super().__init__(module)
         self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
@@ -146,19 +153,18 @@ class GeminiDDP(ModelWrapper):
     def parameters(self, recurse: bool = True):
         return self.module.parameters(recurse)
 
-    def named_parameters(self, prefix: str = '', recurse: bool = True):
+    def named_parameters(self, prefix: str = "", recurse: bool = True):
         return self.module.named_parameters(prefix, recurse)
 
-    def named_buffers(self, prefix: str = '', recurse: bool = True):
+    def named_buffers(self, prefix: str = "", recurse: bool = True):
         return self.module.named_buffers(prefix, recurse)
 
     def named_children(self):
         return self.module.named_children()
 
-    def named_modules(self,
-                      memo: Optional[Set[torch.nn.Module]] = None,
-                      prefix: str = '',
-                      remove_duplicate: bool = True):
+    def named_modules(
+        self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+    ):
         return self.module.named_modules(memo, prefix, remove_duplicate)
 
     @staticmethod
@@ -184,11 +190,9 @@ class GeminiDDP(ModelWrapper):
         # as save/load state dict is overwrited, only return self
         return self
 
-    def _get_non_persistent_buffers_set(self,
-                                        module,
-                                        memo: Optional[Set[nn.Module]] = None,
-                                        prefix: str = '',
-                                        remove_duplicate: bool = True):
+    def _get_non_persistent_buffers_set(
+        self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+    ):
         r"""
         Args:
             memo: a memo to store the set of modules already added to the result
@@ -204,19 +208,20 @@ class GeminiDDP(ModelWrapper):
         if remove_duplicate:
             memo.add(module)
         self_non_persistent_set = set(
-            map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
+            map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
+        )
         for name, sub_module in module._modules.items():
             if sub_module is None:
                 continue
-            submodule_prefix = prefix + ('.' if prefix else '') + name
-            child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
-                                                                            remove_duplicate)
+            submodule_prefix = prefix + ("." if prefix else "") + name
+            child_non_persistent_set = self._get_non_persistent_buffers_set(
+                sub_module, memo, submodule_prefix, remove_duplicate
+            )
             self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
         return self_non_persistent_set
 
     def _post_forward(self):
-        """This function is only triggered for inference.
-        """
+        """This function is only triggered for inference."""
         access_list = list(self.chunk_manager.accessed_chunks)
         # we need to scatter all accessed chunks and move them to their original places
         for chunk in access_list:
@@ -233,7 +238,8 @@ class GeminiDDP(ModelWrapper):
         # check whether we are in a inference mode
         grad_flag = torch.is_grad_enabled()
         if not grad_flag:
-            assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
+            assert (
+                not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup()
             ), "You should run a completed iteration as your warmup iter"
 
         args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
@@ -250,8 +256,7 @@ class GeminiDDP(ModelWrapper):
         return outputs
 
     def _inference_forward(self, *args, **kwargs):
-        """This function is only triggered for inference.
-        """
+        """This function is only triggered for inference."""
         fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
         if not self.scatter_after_inference:
             # gather all chunks
@@ -287,12 +292,14 @@ class GeminiDDP(ModelWrapper):
                 if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
                     error_params.append(self.param2name[param])
             error_str = "\n\t".join(error_params)
-            raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
-                               "The most possible reason is that the model is not compatible with GeminiDDP.\n",
-                               f"{error_str}")
+            raise RuntimeError(
+                "ZERO DDP error: the synchronization of gradients doesn't exit properly.",
+                "The most possible reason is that the model is not compatible with GeminiDDP.\n",
+                f"{error_str}",
+            )
         self._setup_grads_ptr()
         self._logger.debug(
-            f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
+            f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
         )
         self.gemini_manager.post_iter()
 
@@ -314,8 +321,10 @@ class GeminiDDP(ModelWrapper):
         with torch._C.DisableTorchFunction():
             chunk = self.chunk_manager.get_chunk(p)
             if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
-                raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
-                                   "Some unsupported torch function is operated upon this parameter.")
+                raise RuntimeError(
+                    f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
+                    "Some unsupported torch function is operated upon this parameter."
+                )
             self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
             chunk.copy_tensor_to_chunk_slice(p, grad)
             reduced = self.chunk_manager.reduce_chunk(chunk)
@@ -339,12 +348,9 @@ class GeminiDDP(ModelWrapper):
         for tensor in chunk.get_tensors():
             self.grads_device[tensor] = device
 
-    def state_dict(self,
-                   destination=None,
-                   prefix='',
-                   keep_vars=False,
-                   only_rank_0: bool = True,
-                   dtype: torch.dtype = torch.float16):
+    def state_dict(
+        self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16
+    ):
         """Returns a dictionary containing a whole state of the module.
 
         Both parameters and persistent buffers (e.g. running averages) are included.
@@ -391,7 +397,7 @@ class GeminiDDP(ModelWrapper):
             record_tensor = torch.empty([0])
             record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
             if record_flag:
-                record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
+                record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cpu()
 
             assert tensor not in chunk_to_save_data
             chunk_to_save_data[tensor] = record_tensor
@@ -399,8 +405,9 @@ class GeminiDDP(ModelWrapper):
         del temp_chunk
         return chunk_to_save_data
 
-    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool,
-                                dtype: torch.dtype) -> Dict:
+    def _get_param_to_save_data(
+        self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype
+    ) -> Dict:
         """
         get param content from chunks.
 
@@ -459,11 +466,13 @@ class GeminiDDP(ModelWrapper):
                 destination[prefix + name] = buf if keep_vars else buf.detach()
         # save extra states
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-        if getattr(self.__class__, "get_extra_state",
-                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+        if (
+            getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+            is not torch.nn.Module.get_extra_state
+        ):
             destination[extra_state_key] = self.get_extra_state()
 
-    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
+    def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True):
         r"""Copies parameters and buffers from :attr:`state_dict` into
         this module and its descendants. If :attr:`strict` is ``True``, then
         the keys of :attr:`state_dict` must exactly match the keys returned
@@ -491,32 +500,38 @@ class GeminiDDP(ModelWrapper):
         error_msgs: List[str] = []
 
         # copy state_dict so _load_from_state_dict can modify it
-        metadata = getattr(state_dict, '_metadata', None)
+        metadata = getattr(state_dict, "_metadata", None)
         state_dict = state_dict.copy()
         if metadata is not None:
             # mypy isn't aware that "_metadata" exists in state_dict
-            state_dict._metadata = metadata    # type: ignore[attr-defined]
+            state_dict._metadata = metadata  # type: ignore[attr-defined]
 
-        prefix = ''
+        prefix = ""
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
         self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
 
         if strict:
             if len(unexpected_keys) > 0:
                 error_msgs.insert(
-                    0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
-                        '"{}"'.format(k) for k in unexpected_keys)))
+                    0,
+                    "Unexpected key(s) in state_dict: {}. ".format(
+                        ", ".join('"{}"'.format(k) for k in unexpected_keys)
+                    ),
+                )
             if len(missing_keys) > 0:
                 error_msgs.insert(
-                    0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
+                    0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
+                )
 
         if len(error_msgs) > 0:
-            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                self.__class__.__name__, "\n\t".join(error_msgs)))
+            raise RuntimeError(
+                "Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs))
+            )
         return _IncompatibleKeys(missing_keys, unexpected_keys)
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                              error_msgs):
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
         r"""Copies parameters and buffers from :attr:`state_dict` into only
         this module, but not its descendants. This is called on every submodule
         in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
@@ -564,19 +579,21 @@ class GeminiDDP(ModelWrapper):
                     input_param = input_param[0]
                 if input_param.shape != dest_tensor.shape:
                     # local shape should match the one in checkpoint
-                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
-                                      'the shape in current model is {}.'.format(state_key, input_param.shape,
-                                                                                 dest_tensor.shape))
+                    error_msgs.append(
+                        "size mismatch for {}: copying a param with shape {} from checkpoint, "
+                        "the shape in current model is {}.".format(state_key, input_param.shape, dest_tensor.shape)
+                    )
                     return
                 try:
                     with torch.no_grad():
                         copy_func(input_param)
                 except Exception as ex:
-                    error_msgs.append('While copying the parameter named "{}", '
-                                      'whose dimensions in the model are {} and '
-                                      'whose dimensions in the checkpoint are {}, '
-                                      'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
-                                                                           input_param.size(), ex.args))
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(state_key, dest_tensor.size(), input_param.size(), ex.args)
+                    )
             elif strict:
                 missing_keys.append(state_key)
 
@@ -600,15 +617,15 @@ class GeminiDDP(ModelWrapper):
 
             for tensor, tensor_info in chunk.tensors_info.items():
                 parameter_name = fp32_to_name[tensor]
-                parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
+                parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
                 load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
 
             if chunk.is_gathered:
                 chunk.cuda_global_chunk.copy_(temp_chunk)
             elif chunk.cuda_shard is not None:
-                chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
+                chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
             else:
-                chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
+                chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
 
             del temp_chunk
 
@@ -622,8 +639,10 @@ class GeminiDDP(ModelWrapper):
                 load(name, buf, buf.copy_)
 
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-        if getattr(self.__class__, "set_extra_state",
-                   torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
+        if (
+            getattr(self.__class__, "set_extra_state", torch.nn.Module.set_extra_state)
+            is not torch.nn.Module.set_extra_state
+        ):
             if extra_state_key in state_dict:
                 self.set_extra_state(state_dict[extra_state_key])
             elif strict:
@@ -634,7 +653,7 @@ class GeminiDDP(ModelWrapper):
         if strict:
             for key in state_dict.keys():
                 if key.startswith(prefix) and key != extra_state_key:
-                    input_name = key[len(prefix):]
+                    input_name = key[len(prefix) :]
                     if input_name not in local_state:
                         unexpected_keys.append(key)
 
@@ -659,18 +678,22 @@ class GeminiDDP(ModelWrapper):
             p.data = p.data.to(self.mixed_precision)
 
             # register the fp16 parameter and fp32 parameter in the chunk manager
-            self.chunk_manager.register_tensor(tensor=p,
-                                               group_type='fp16_param',
-                                               config_key=dp_world_size,
-                                               process_group=self.dp_process_group,
-                                               cpu_offload=cpu_offload,
-                                               pin_memory=pin_memory)
-            self.chunk_manager.register_tensor(tensor=fp32_p,
-                                               group_type='fp32_param',
-                                               config_key=dp_world_size,
-                                               process_group=self.dp_process_group,
-                                               cpu_offload=cpu_offload,
-                                               pin_memory=pin_memory)
+            self.chunk_manager.register_tensor(
+                tensor=p,
+                group_type="fp16_param",
+                config_key=dp_world_size,
+                process_group=self.dp_process_group,
+                cpu_offload=cpu_offload,
+                pin_memory=pin_memory,
+            )
+            self.chunk_manager.register_tensor(
+                tensor=fp32_p,
+                group_type="fp32_param",
+                config_key=dp_world_size,
+                process_group=self.dp_process_group,
+                cpu_offload=cpu_offload,
+                pin_memory=pin_memory,
+            )
 
             self.fp16_params.append(p)
             self.fp32_params.append(fp32_p)
@@ -694,7 +717,7 @@ class GeminiDDP(ModelWrapper):
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.to(self.mixed_precision)
 
-    def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
+    def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, "LazyTensor"]) -> None:
         """Convert parameter to ColoParameter in-place.
         Args:
             p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted
@@ -709,12 +732,14 @@ class GeminiDDP(ModelWrapper):
         p.__class__ = ColoParameter
         p.__init__(p, requires_grad=requires_grad)
 
-    def state_dict_shard(self,
-                         prefix: str = '',
-                         keep_vars: bool = False,
-                         max_shard_size: int = 1024,
-                         only_rank_0: bool = True,
-                         dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]:
+    def state_dict_shard(
+        self,
+        prefix: str = "",
+        keep_vars: bool = False,
+        max_shard_size: int = 1024,
+        only_rank_0: bool = True,
+        dtype: torch.dtype = torch.float16,
+    ) -> Iterator[Tuple[OrderedDict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
 
         Both parameters and persistent buffers (e.g. running averages) are included.
@@ -770,8 +795,10 @@ class GeminiDDP(ModelWrapper):
                     yield block, block_size
         # save extra states
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-        if getattr(self.__class__, "get_extra_state",
-                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+        if (
+            getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+            is not torch.nn.Module.get_extra_state
+        ):
            extra_state = self.get_extra_state()
            block, block_size = sharder.append_param(extra_state_key, extra_state)
            if block is not None:
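The hunks above are purely mechanical re-styling (quote normalization plus black-style wrapping of signatures and calls); none of the keyword arguments of GeminiDDP.__init__ change meaning. For orientation, here is a minimal usage sketch of the constructor shown in this diff. It is illustrative and not part of the commit: the import path colossalai.zero.GeminiDDP, the launch_from_torch call, and the toy model are assumptions, and the wrapper only works inside an initialized distributed environment (for example a single-process torchrun launch).

import torch
import torch.nn as nn

import colossalai
from colossalai.zero import GeminiDDP  # assumed public import path

# GeminiDDP needs torch.distributed to be initialized; under `torchrun` a
# single-GPU launch can be set up like this (empty config dict).
colossalai.launch_from_torch(config={})

# A toy model stands in for a real network.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

# Keyword arguments mirror the signature in the diff above; values are illustrative.
gemini_model = GeminiDDP(
    model,
    chunk_init_device=torch.device("cuda"),
    placement_policy="static",
    shard_param_frac=1.0,  # shard all parameters across data-parallel ranks
    pin_memory=True,  # pin offloaded CPU memory for faster host-to-device copies
    mixed_precision=torch.float16,
)

# A forward pass behaves like a regular nn.Module call; inputs are cast to
# `mixed_precision` internally. Training normally pairs this with Gemini's
# optimizer wrapper, which is out of scope for this sketch.
# out = gemini_model(torch.randn(8, 1024, device="cuda"))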