mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[polish] polish ColoTensor and its submodules (#2537)
This commit is contained in:
@@ -37,12 +37,11 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
|
||||
# detaching tensor is necessary for optimizers.
|
||||
requires_grad = param.requires_grad
|
||||
# param is the global tensor.
|
||||
|
||||
|
||||
if param.device.type == "meta":
|
||||
colo_param = ColoParameter(param, requires_grad=requires_grad)
|
||||
else:
|
||||
else:
|
||||
colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
|
||||
|
||||
|
||||
# if default_shard_plan exists, shard the param during initialization.
|
||||
# This can reduce the model size after initialization.
|
||||
@@ -129,32 +128,29 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
delattr(submodule, param_name)
|
||||
setattr(submodule, param_name, colo_param)
|
||||
colo_param.shared_param_modules.append(submodule)
|
||||
|
||||
meta_param_flag = 0
|
||||
meta_buffer_flag = 0
|
||||
|
||||
param_number = 0
|
||||
meta_param_number = 0
|
||||
buffer_number = 0
|
||||
meta_buffer_number = 0
|
||||
|
||||
for param in module.parameters():
|
||||
if param.device.type=="meta":
|
||||
meta_param_flag = 1
|
||||
if meta_param_flag == 1 and param.device.type!="meta":
|
||||
raise ValueError("Meta parameters and valued parameters can not be in the same model")
|
||||
|
||||
param_number += 1
|
||||
meta_param_number += (param.device.type == 'meta')
|
||||
|
||||
for buffer in module.buffers():
|
||||
if buffer.device.type=="meta":
|
||||
meta_buffer_flag = 1
|
||||
if meta_buffer_flag == 1 and buffer.device.type!="meta":
|
||||
raise ValueError("Meta buffers and valued buffers can not be in the same model")
|
||||
|
||||
if meta_param_flag==1 and meta_buffer_flag==1:
|
||||
pass
|
||||
elif meta_buffer_flag==0 and meta_param_flag==1:
|
||||
for name, buf in module.named_buffers():
|
||||
module._buffers[name] = module._buffers[name].to(device=self._device)
|
||||
elif meta_param_flag==0 and meta_buffer_flag==1:
|
||||
for name, param in module.named_parameters():
|
||||
module._parameters[name] = module._parameters[name].to(device=self._device)
|
||||
else:
|
||||
module.to(self._device)
|
||||
|
||||
buffer_number += 1
|
||||
meta_buffer_number += (buffer.device.type == 'meta')
|
||||
|
||||
if meta_param_number > 0 and meta_param_number != param_number:
|
||||
raise ValueError("Meta parameters and valued parameters can not be in the same model")
|
||||
if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
|
||||
raise ValueError("Meta buffers and valued buffers can not be in the same model")
|
||||
|
||||
if meta_buffer_number == 0:
|
||||
for buffer in module.buffers():
|
||||
buffer.data = buffer.data.to(device=self._device)
|
||||
|
||||
|
||||
def post_process_colo_init_ctx(model: torch.nn.Module,
|
||||
device: torch.device = torch.device('cpu'),
|
||||
|
Reference in New Issue
Block a user