[Tensor] add from_pretrained support and bert pretrained test (#921)

* add from_pretrained support and test

* polish

* polish

* polish

* polish
Author: Ziyue Jiang
Date: 2022-05-09 16:11:47 +08:00
Committed by: GitHub
Parent: 1d625fcd36
Commit: c195d2814c
4 changed files with 158 additions and 20 deletions
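At a glance, the feature this commit enables looks roughly like the following. This is a hedged sketch, not the test added by the commit (the test file is not shown in this excerpt); the transformers dependency, the checkpoint name, and the import path are assumptions:

import torch
from transformers import BertForMaskedLM  # assumed dependency of the new test
from colossalai.utils.model.colo_init_context import ColoInitContext  # module path assumed

# Building the model inside ColoInitContext routes every parameter through
# the patched register_parameter/__setattr__ shown below, so the weights
# loaded by from_pretrained() end up wrapped as ColoParameters.
with ColoInitContext(lazy_memory_allocate=False, device=torch.device('cpu')):
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')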


@@ -4,8 +4,86 @@ from colossalai.tensor import ColoTensor, ColoParameter
import types
from torch import nn
from typing import Iterator, Tuple, Union, Optional

# Adapted from torch.nn.modules.module.Module.register_parameter
def _register_parameter_with_colotensor(self, name: str, param):
    if '_parameters' not in self.__dict__:
        raise AttributeError("cannot assign parameter before Module.__init__() call")
    if not isinstance(name, torch._six.string_classes):
        raise TypeError("parameter name should be a string. "
                        "Got {}".format(torch.typename(name)))
    if '.' in name:
        raise KeyError("parameter name can't contain \".\"")
    if name == '':
        raise KeyError("parameter name can't be empty string \"\"")
    if hasattr(self, name) and name not in self._parameters:
        raise KeyError("attribute '{}' already exists".format(name))

    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or ColoParameter or None required)".format(torch.typename(param), name))
    elif param.grad_fn:
        raise ValueError("Cannot assign non-leaf Tensor to parameter '{0}'. Model "
                         "parameters must be created explicitly. To express '{0}' "
                         "as a function of another Tensor, compute the value in "
                         "the forward() method.".format(name))
    else:
        self._parameters[name] = param
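
# --- Illustrative example, not part of this diff: once ColoInitContext has
# --- installed the patch above, register_parameter accepts a ColoParameter
# --- where stock PyTorch would raise a TypeError on its isinstance check.
# --- init_from_torch_tensor is assumed to have a usable default for
# --- save_payload (the diff only shows it called with that kwarg).
#
#     layer = torch.nn.Linear(4, 4)
#     w = ColoParameter.init_from_torch_tensor(tensor=torch.randn(4, 4))
#     layer.register_parameter('colo_weight', w)
# --- end example ---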
# Adapted from torch.nn.modules.module.Module.__setattr__
def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):

    def remove_from(*dicts_or_sets):
        for d in dicts_or_sets:
            if name in d:
                if isinstance(d, dict):
                    del d[name]
                else:
                    d.discard(name)

    params = self.__dict__.get('_parameters')
    if isinstance(value, (ColoTensor, torch.nn.Parameter)):
        if params is None:
            raise AttributeError("cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
        self.register_parameter(name, value)
    elif params is not None and name in params:
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)".format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        if isinstance(value, torch.nn.Module):
            if modules is None:
                raise AttributeError("cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
            modules[name] = value
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)".format(torch.typename(value), name))
            modules[name] = value
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                if value is not None and not isinstance(value, torch.Tensor):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)".format(torch.typename(value), name))
                buffers[name] = value
            else:
                object.__setattr__(self, name, value)
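
# --- Illustrative example, not part of this diff: the dispatch order above
# --- mirrors stock PyTorch (parameters, then modules, then buffers, then
# --- plain attributes); the only behavioral change is that the first
# --- isinstance check now also catches ColoTensor. Assuming the patch is
# --- installed:
#
#     m = torch.nn.Linear(2, 2)
#     m.weight = ColoParameter.init_from_torch_tensor(tensor=torch.randn(2, 2))
#     # -> first branch: removed from buffers/modules, then register_parameter()
#     m.score = torch.tensor(0.0)
#     # -> plain tensor with no existing entry: falls through to object.__setattr__
# --- end example ---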

def ColoModulize(module):
    """
@@ -64,7 +142,6 @@ def ColoModulize(module):
    module.colo_named_parameters = funcType(colo_named_parameters, module)
    module._colo_visited = True
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

    def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
@@ -77,11 +154,16 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
        self._lazy_memory_allocate = lazy_memory_allocate
        self._device = device

        # TODO(jzy) replace it with the old __setattr__ in the exit() of the context?
        torch.nn.Module.__setattr__ = _setattr_with_colotensor
        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
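
# --- Hypothetical sketch, not part of this diff: what the TODO above
# --- suggests. Stash the originals on __enter__ and restore them on
# --- __exit__ so the patch does not leak outside the context:
#
#     _old_setattr = torch.nn.Module.__setattr__
#     _old_register_parameter = torch.nn.Module.register_parameter
#     ...
#     torch.nn.Module.__setattr__ = _old_setattr
#     torch.nn.Module.register_parameter = _old_register_parameter
# --- end sketch ---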

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.
        FIXME(fjr) The module may be passed to this function multiple times?
        """
        if hasattr(module, '_colo_visited'):
            return
@@ -100,7 +182,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
            tensor_detached = param.to(self._device).detach()
            tensor_detached.requires_grad = requires_grad

            colo_param = ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)
            setattr(module, name, colo_param)

        ColoModulize(module)
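
Putting the pieces together: on entry the context patches nn.Module globally, and _post_init_method then converts each freshly constructed module's parameters in place. A hedged end-to-end sketch, using only names visible in this diff (colo_named_parameters is assumed to take no required arguments, mirroring named_parameters):

with ColoInitContext(device=torch.device('cpu')):
    net = torch.nn.Linear(8, 8)

# ColoModulize attached colo_named_parameters to the module above
for name, param in net.colo_named_parameters():
    assert isinstance(param, ColoParameter)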