[gemini] add get static torch model (#2356)
@@ -389,19 +389,6 @@ class ZeroDDP(ColoDDP):
        del temp_chunk
        return param_to_save_data

    def torch_named_parameters(self):
        """
        Get named_parameters() of self.module. It is used the same way as PyTorch's named_parameters and returns the real param.data payload.
        It works the same as torch.nn.Module.named_parameters.
        """
        params_list = [p for p in self.parameters(recurse=True)]
        param_to_save_data = self._get_param_to_save_data(params_list, False)
        for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
            if p is not None:
                assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
                record_parameter = param_to_save_data[p]
                yield name, record_parameter

    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
        r"""Saves module state to `destination` dictionary, containing a state
        of the module, but not its descendants. This is called on every
@@ -418,6 +405,7 @@ class ZeroDDP(ColoDDP):
        assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."

        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
        # TODO: (HELSON) deal with ddp ignored parameters
        for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
            if p is not None:
                assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
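With `torch_named_parameters` removed, the full fp32 parameters are exposed only through the chunk-aware `state_dict` path shown above. A minimal sketch of how a caller might save a checkpoint with that API, assuming `gemini_model` is a placeholder for an already-initialized ZeroDDP/GeminiDDP model and that torch.distributed has been set up:

import torch
import torch.distributed as dist

# `gemini_model` stands in for an initialized ZeroDDP/GeminiDDP instance (not constructed here).
state_dict = gemini_model.state_dict(only_rank_0=True)  # full params are gathered and populated on rank 0
if dist.get_rank() == 0:
    torch.save(state_dict, "checkpoint.pt")
dist.barrier()  # keep the other ranks in step with rank 0
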
@@ -1,5 +1,10 @@
from collections import OrderedDict
from copy import copy
from typing import Optional, Set

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.gemini.chunk import Chunk
from colossalai.utils import get_current_device
@@ -21,30 +26,88 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
    return total_temp


# TODO() not work for module where two params share the same tensor.
def _add_param(model, name, param):
    name_list = name.split('.')
    module = model._modules[name_list[0]]
    for i in range(1, len(name_list) - 1):
        module = module._modules[name_list[i]]
    module._parameters[name_list[-1]] = param
def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
    """Get a dfs module list of the given module. Its order is the same as the order in which the modules were created.
    """
    if memo is None:
        memo = set()
    if module not in memo:
        for name, submodule in module._modules.items():
            if submodule is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
                yield m

        memo.add(module)
        yield prefix, module


def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
    """convert_to_torch_module
def _get_shallow_copy_model(model: nn.Module):
    """Get a shallow copy of the given model. Each submodule is different from the original submodule.
    But the new submodule and the old submodule share all attributes.
    """
    name_to_module = dict()
    for name, module in _get_dfs_module_list(model):
        new_module = copy(module)
        new_module._modules = OrderedDict()
        for subname, submodule in module._modules.items():
            if submodule is None:
                continue
            full_name = name + ('.' if name else '') + subname
            setattr(new_module, subname, name_to_module[full_name])
        name_to_module[name] = new_module
    return name_to_module['']

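The shallow-copy helper above relies on the fact that `copy.copy` on an `nn.Module` produces a new module object whose attributes, including the `_parameters` and `_modules` dicts, are initially shared with the original. A small standalone illustration of that behavior (it does not call the helpers in this diff and uses a plain `nn.Linear` purely as an example):

from collections import OrderedDict
from copy import copy

import torch.nn as nn

orig = nn.Linear(4, 2)
clone = copy(orig)                  # new module object, shallow-copied state
clone._modules = OrderedDict()      # rebind, so the original's child dict is untouched
assert clone is not orig
assert clone.weight is orig.weight  # parameters are still shared with the original

This is also why `get_static_torch_model` below rebinds `torch_module._parameters` to a fresh `OrderedDict` instead of clearing it in place: clearing would mutate the dict still shared with the original Gemini-wrapped module.
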
def get_static_torch_model(gemini_ddp_model,
                           device=torch.device("cpu"),
                           dtype=torch.float32,
                           only_rank_0=True) -> torch.nn.Module:
    """Get a static torch.nn.Module model from the given GeminiDDP module.
    You should notice that the original GeminiDDP model is not modified.
    Thus, you can use the original model in further training.
    But you should not use the returned torch model to train; this can cause unexpected errors.

    Args:
        gemini_ddp_model (GeminiDDP): a gemini ddp model
        device (torch.device): the device of the final torch model
        dtype (torch.dtype): the dtype of the final torch model
        only_rank_0 (bool): if True, only rank0 has the converted torch model

    Returns:
        torch.nn.Module: a torch model that contains the params of gemini_ddp_model
        torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
    """
    from colossalai.nn.parallel import GeminiDDP
    assert isinstance(gemini_ddp_model, GeminiDDP)
    module = gemini_ddp_model.module

    # replace ColoTensor with torch.Tensor in module
    for n, p in gemini_ddp_model.torch_named_parameters():
        _add_param(module, n, p)
    state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
    colo_model = gemini_ddp_model.module
    torch_model = _get_shallow_copy_model(colo_model)

    return module
    if not only_rank_0 or dist.get_rank() == 0:
        # record the mapping relationship between colo parameters and torch parameters
        colo_to_torch = dict()
        for (name, colo_module), (_, torch_module) in \
                zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
            # clean the parameter list of the new torch module
            torch_module._parameters = OrderedDict()
            for sufix_param_name, param in colo_module.named_parameters(recurse=False):
                # get the full name of the parameter
                full_param_name = name + ('.' if name else '') + sufix_param_name

                if full_param_name not in state_dict:
                    # this means the parameter is shared by multiple modules
                    # we should use colo_to_torch to get the torch parameter created before
                    assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
                    torch_param = colo_to_torch[param]
                else:
                    # we meet the parameter the first time, just use the state dict to get the data
                    state_param = state_dict[full_param_name]
                    torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
                    colo_to_torch[param] = torch_param

                setattr(torch_module, sufix_param_name, torch_param)
    dist.barrier()

    return torch_model
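Taken together, `get_static_torch_model` gathers the full parameters through `state_dict(only_rank_0=...)` and attaches them to a shallow copy of the wrapped module, leaving the original GeminiDDP model untouched. A minimal usage sketch; the import path is an assumption based on the utils module patched above and may differ between ColossalAI versions, and `gemini_model` is a placeholder for an already-initialized GeminiDDP instance:

import torch
import torch.distributed as dist

# Assumed import location of the new helper; adjust to your ColossalAI version.
from colossalai.nn.parallel.utils import get_static_torch_model

# `gemini_model` stands in for an initialized GeminiDDP model (not constructed here).
torch_model = get_static_torch_model(gemini_model,
                                     device=torch.device("cpu"),
                                     dtype=torch.float32,
                                     only_rank_0=True)
if dist.get_rank() == 0:
    # Use the static copy for checkpointing or numeric checks, not for further training.
    torch.save(torch_model.state_dict(), "static_model.pt")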