[gemini] add get static torch model (#2356)

HELSON
2023-01-06 13:41:19 +08:00
committed by GitHub
parent 7a332b1734
commit 48d33b1b17
4 changed files with 164 additions and 118 deletions

View File

@@ -389,19 +389,6 @@ class ZeroDDP(ColoDDP):
del temp_chunk
return param_to_save_data
def torch_named_parameters(self):
"""
Get the named parameters of self.module. It works like torch.nn.Module.named_parameters(),
but yields the real param.data payload for each parameter.
"""
params_list = [p for p in self.parameters(recurse=True)]
param_to_save_data = self._get_param_to_save_data(params_list, False)
for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
if p is not None:
assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[p]
yield name, record_parameter
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
@@ -418,6 +405,7 @@ class ZeroDDP(ColoDDP):
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
# TODO: (HELSON) deal with ddp ignored parameters
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
if p is not None:
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
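A side note on the pattern in this hunk: _save_to_state_dict pairs each training parameter with its fp32 copy and looks that copy up in a dict keyed by the tensor object. Below is a minimal, self-contained sketch of the same lookup using plain PyTorch; fp32_params and param_to_save_data are toy stand-ins, not the real ZeroDDP attributes.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# toy stand-in for self.fp32_params: one fp32 copy per training parameter
fp32_params = [p.detach().clone().float() for p in model.parameters()]

# toy stand-in for _get_param_to_save_data: map each fp32 copy to the payload to save
param_to_save_data = {fp32_p: fp32_p.data for fp32_p in fp32_params}

destination = {}
for (name, p), fp32_p in zip(model.named_parameters(), fp32_params):
    assert fp32_p in param_to_save_data, f"parameter '{name}' has no recorded payload"
    destination[name] = param_to_save_data[fp32_p]

print(sorted(destination))  # ['bias', 'weight']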

View File

@@ -1,5 +1,10 @@
from collections import OrderedDict
from copy import copy
from typing import Optional, Set
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import Chunk
from colossalai.utils import get_current_device
@@ -21,30 +26,88 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
return total_temp
# TODO() does not work for modules where two params share the same tensor.
def _add_param(model, name, param):
name_list = name.split('.')
module = model._modules[name_list[0]]
for i in range(1, len(name_list) - 1):
module = module._modules[name_list[i]]
module._parameters[name_list[-1]] = param
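For context on the helper being removed here: _add_param resolves a dotted parameter name to the submodule that owns it and installs the tensor there. A small illustration on a made-up toy model, assuming _add_param as defined above:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

# replace the parameter registered under '2.weight' (the second Linear's weight)
new_weight = nn.Parameter(torch.zeros(2, 4))
_add_param(model, '2.weight', new_weight)
assert model[2].weight is new_weight

The TODO above presumably refers to weight tying: named_parameters() deduplicates shared tensors, so when two modules share one parameter only one of the two would ever receive the replacement.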
def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
"""Get a dfs module list of the given module. Its order is same as the order of creations of modules.
"""
if memo is None:
memo = set()
if module not in memo:
for name, submodule in module._modules.items():
if submodule is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
yield m
memo.add(module)
yield prefix, module
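A quick sketch of the traversal order on a toy model, using the generator defined above (the layout is made up; the output assumes standard nn.Module registration order):

import torch.nn as nn

model = nn.Module()
model.encoder = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
model.head = nn.Linear(4, 2)

# children are yielded before their parent; the root comes last with an empty prefix
print([name for name, _ in _get_dfs_module_list(model)])
# ['encoder.0', 'encoder.1', 'encoder', 'head', '']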
def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
"""convert_to_torch_module
def _get_shallow_copy_model(model: nn.Module):
"""Get a shallow copy of the given model. Each submodule is different from the original submodule.
But the new submodule and the old submodule share all attributes.
"""
name_to_module = dict()
for name, module in _get_dfs_module_list(model):
new_module = copy(module)
new_module._modules = OrderedDict()
for subname, submodule in module._modules.items():
if submodule is None:
continue
full_name = name + ('.' if name else '') + subname
setattr(new_module, subname, name_to_module[full_name])
name_to_module[name] = new_module
return name_to_module['']
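To make the sharing behaviour concrete, a small check using the function above (toy model; the assertions describe the copy before get_static_torch_model swaps the parameters in):

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
shadow = _get_shallow_copy_model(model)

assert shadow is not model and shadow[0] is not model[0]  # every submodule is a new object
assert shadow[0].weight is model[0].weight                # but it still shares the original attributes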
def get_static_torch_model(gemini_ddp_model,
device=torch.device("cpu"),
dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module:
"""Get a static torch.nn.Module model from the given GeminiDDP module.
Note that the original GeminiDDP model is not modified, so it can still be used for further training.
However, the returned torch model should not be trained; doing so can cause unexpected errors.
Args:
gemini_ddp_model (GeminiDDP): a gemini ddp model
device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model
only_rank_0 (bool): if True, only rank 0 has the converted torch model
Returns:
torch.nn.Module: a torch model that contains the params of gemini_ddp_model
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
"""
from colossalai.nn.parallel import GeminiDDP
assert isinstance(gemini_ddp_model, GeminiDDP)
module = gemini_ddp_model.module
# replace each ColoTensor in the module with a plain torch.Tensor
for n, p in gemini_ddp_model.torch_named_parameters():
_add_param(module, n, p)
state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
colo_model = gemini_ddp_model.module
torch_model = _get_shallow_copy_model(colo_model)
return module
if not only_rank_0 or dist.get_rank() == 0:
# record the mapping between colo parameters and torch parameters
colo_to_torch = dict()
for (name, colo_module), (_, torch_module) in \
zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
# clean the parameter list of the new torch module
torch_module._parameters = OrderedDict()
for sufix_param_name, param in colo_module.named_parameters(recurse=False):
# get the full name of the parameter
full_param_name = name + ('.' if name else '') + sufix_param_name
if full_param_name not in state_dict:
# the parameter is shared by multiple modules,
# so reuse the torch parameter created for it earlier
assert param in colo_to_torch, f"cannot find parameter `{full_param_name}` in the GeminiDDP module"
torch_param = colo_to_torch[param]
else:
# this is the first time we meet the parameter, so take its data from the state dict
state_param = state_dict[full_param_name]
torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
colo_to_torch[param] = torch_param
setattr(torch_module, sufix_param_name, torch_param)
dist.barrier()
return torch_model
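A hedged usage sketch of the new helper. This is not a drop-in script: gemini_model stands for an already-constructed and trained GeminiDDP instance, and the import path for get_static_torch_model is an assumption that may differ between versions.

import torch
import torch.distributed as dist
from colossalai.nn.parallel.utils import get_static_torch_model  # path assumed; adjust to your version

# gemini_model: an existing GeminiDDP-wrapped model (construction omitted here)
torch_model = get_static_torch_model(gemini_model, device=torch.device("cpu"), dtype=torch.float32)

if dist.get_rank() == 0:
    # the result is an ordinary torch.nn.Module, so the usual checkpointing path applies
    torch.save(torch_model.state_dict(), "checkpoint.pt")

With the default only_rank_0=True, only rank 0 receives the materialized weights, hence the rank guard around the save.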