mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-30 15:55:09 +00:00
[Gemini] fix the convert_to_torch_module bug (#2269)
This commit is contained in:
@@ -360,6 +360,48 @@ class ZeroDDP(ColoDDP):
|
||||
destination = hook_result
|
||||
return destination
|
||||
|
||||
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
|
||||
"""
|
||||
get param content from chunks.
|
||||
|
||||
Args:
|
||||
param_list (_type_): a list of torch.nn.Parameters
|
||||
only_rank_0 (_type_): _description_
|
||||
|
||||
Returns:
|
||||
Dict: a dict whose key is param name and value is param with correct payload
|
||||
"""
|
||||
# save parameters
|
||||
param_to_save_data = dict()
|
||||
chunk_list = self.chunk_manager.get_chunks(param_list)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
record_tensor = torch.empty([0])
|
||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||
if record_flag:
|
||||
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
|
||||
|
||||
assert tensor not in param_to_save_data
|
||||
param_to_save_data[tensor] = record_tensor
|
||||
|
||||
del temp_chunk
|
||||
return param_to_save_data
|
||||
|
||||
def torch_named_parameters(self):
|
||||
"""
|
||||
get named_parameters() of self.module. It is used the same of PyTorch param and returns the real param.data payload.
|
||||
It works the same as torch.Module named_parameters
|
||||
"""
|
||||
params_list = [p for p in self.parameters(recurse=True)]
|
||||
param_to_save_data = self._get_param_to_save_data(params_list, False)
|
||||
for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
|
||||
if p is not None:
|
||||
assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[p]
|
||||
yield name, record_parameter
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
@@ -375,23 +417,7 @@ class ZeroDDP(ColoDDP):
|
||||
"""
|
||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
|
||||
# save parameters
|
||||
param_to_save_data = dict()
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
record_tensor = torch.empty([0])
|
||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||
if record_flag:
|
||||
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
|
||||
|
||||
assert tensor not in param_to_save_data
|
||||
param_to_save_data[tensor] = record_tensor
|
||||
|
||||
del temp_chunk
|
||||
|
||||
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
if p is not None:
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
|
||||
Reference in New Issue
Block a user