[polish] polish ColoTensor and its submodules (#2537)

HELSON
2023-02-03 11:44:10 +08:00
committed by GitHub
parent 51d4d6e718
commit 552183bb74
6 changed files with 75 additions and 65 deletions


@@ -37,12 +37,11 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
     # param is the global tensor.
     if param.device.type == "meta":
         colo_param = ColoParameter(param, requires_grad=requires_grad)
     else:
         colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
     # if default_shard_plan exists, shard the param during initialization.
     # This can reduce the model size after initialization.
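
Note: the meta branch kept above exists because tensors on the "meta" device are shape-only and hold no storage, so there is nothing to copy to a real device or dtype. Below is a minimal runnable sketch of the same dispatch, with torch.nn.Parameter standing in for ColoParameter so it runs without ColossalAI; the make_param helper is hypothetical, for illustration only.

import torch

# Sketch of the branch above: "meta" parameters are wrapped as-is,
# materialized parameters are moved to the target device/dtype first.
def make_param(param: torch.nn.Parameter,
               device: torch.device = torch.device('cpu'),
               dtype: torch.dtype = torch.float) -> torch.nn.Parameter:
    requires_grad = param.requires_grad
    if param.device.type == "meta":
        # no storage to copy; keep the shape-only meta tensor as it is
        return torch.nn.Parameter(param.detach(), requires_grad=requires_grad)
    # detaching is necessary so the new Parameter is a graph leaf
    return torch.nn.Parameter(param.detach().to(device=device, dtype=dtype),
                              requires_grad=requires_grad)

meta_p = torch.nn.Parameter(torch.empty(2, 2, device="meta"))
real_p = torch.nn.Parameter(torch.ones(2, 2, dtype=torch.float64))
print(make_param(meta_p).device)    # meta
print(make_param(real_p).dtype)     # torch.float32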
@@ -129,32 +128,29 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             delattr(submodule, param_name)
             setattr(submodule, param_name, colo_param)
             colo_param.shared_param_modules.append(submodule)
-        meta_param_flag = 0
-        meta_buffer_flag = 0
+        param_number = 0
+        meta_param_number = 0
+        buffer_number = 0
+        meta_buffer_number = 0
         for param in module.parameters():
-            if param.device.type=="meta":
-                meta_param_flag = 1
-            if meta_param_flag == 1 and param.device.type!="meta":
-                raise ValueError("Meta parameters and valued parameters can not be in the same model")
+            param_number += 1
+            meta_param_number += (param.device.type == 'meta')
         for buffer in module.buffers():
-            if buffer.device.type=="meta":
-                meta_buffer_flag = 1
-            if meta_buffer_flag == 1 and buffer.device.type!="meta":
-                raise ValueError("Meta buffers and valued buffers can not be in the same model")
-        if meta_param_flag==1 and meta_buffer_flag==1:
-            pass
-        elif meta_buffer_flag==0 and meta_param_flag==1:
-            for name, buf in module.named_buffers():
-                module._buffers[name] = module._buffers[name].to(device=self._device)
-        elif meta_param_flag==0 and meta_buffer_flag==1:
-            for name, param in module.named_parameters():
-                module._parameters[name] = module._parameters[name].to(device=self._device)
-        else:
-            module.to(self._device)
+            buffer_number += 1
+            meta_buffer_number += (buffer.device.type == 'meta')
+        if meta_param_number > 0 and meta_param_number != param_number:
+            raise ValueError("Meta parameters and valued parameters can not be in the same model")
+        if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
+            raise ValueError("Meta buffers and valued buffers can not be in the same model")
+        if meta_buffer_number == 0:
+            for buffer in module.buffers():
+                buffer.data = buffer.data.to(device=self._device)

 def post_process_colo_init_ctx(model: torch.nn.Module,
                                device: torch.device = torch.device('cpu'),
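
Note: the rewritten check above trades the order-dependent flag logic for simple counters, so a model must be either fully materialized or fully on the meta device, and buffers are only moved when nothing is meta. Below is a standalone sketch of that invariant using plain torch.nn; the validate_meta_state helper is hypothetical, for illustration only.

import torch

# Counting check as introduced above: mixing "meta" and materialized
# parameters (or buffers) in one model is rejected outright.
def validate_meta_state(module: torch.nn.Module) -> None:
    param_number = meta_param_number = 0
    buffer_number = meta_buffer_number = 0
    for param in module.parameters():
        param_number += 1
        meta_param_number += (param.device.type == 'meta')
    for buffer in module.buffers():
        buffer_number += 1
        meta_buffer_number += (buffer.device.type == 'meta')
    if meta_param_number > 0 and meta_param_number != param_number:
        raise ValueError("Meta parameters and valued parameters can not be in the same model")
    if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
        raise ValueError("Meta buffers and valued buffers can not be in the same model")

validate_meta_state(torch.nn.Linear(4, 4))    # all parameters materialized: passes
mixed = torch.nn.Sequential(torch.nn.Linear(4, 4),
                            torch.nn.Linear(4, 4, device="meta"))
try:
    validate_meta_state(mixed)
except ValueError as err:
    print(err)    # Meta parameters and valued parameters can not be in the same model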