Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[Tensor] add from_pretrained support and bert pretrained test (#921)
* add from_pretrained support and test
* polish
* polish
* polish
* polish
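Before the diff, a usage sketch of what the new from_pretrained support is meant to allow. This is not the PR's test: the model class, checkpoint name, and import path are illustrative assumptions, and colo_named_parameters (attached by ColoModulize in the diff below) is assumed to yield (name, param) pairs like nn.Module.named_parameters.

    # Hypothetical usage sketch, not the test added by this PR.
    import torch
    from transformers import BertForMaskedLM    # assumed HuggingFace-style from_pretrained API
    from colossalai.tensor import ColoParameter
    from colossalai.utils.model.colo_init_context import ColoInitContext    # import path assumed

    # Build the pretrained model inside the context so every parameter is
    # re-created as a ColoParameter by _post_init_method.
    with ColoInitContext(lazy_memory_allocate=False, device=torch.device('cpu')):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    # colo_named_parameters is patched onto modules by ColoModulize below.
    for name, param in model.colo_named_parameters():
        assert isinstance(param, ColoParameter)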
@@ -4,8 +4,86 @@ from colossalai.tensor import ColoTensor, ColoParameter
 import types
 
 from torch import nn
-from typing import Iterator, Tuple, Union
+from typing import Iterator, Tuple, Union, Optional
+
+
+# Adapted from torch.nn.module.Module.register_param
+def _register_parameter_with_colotensor(self, name: str, param):
+    if '_parameters' not in self.__dict__:
+        raise AttributeError(
+            "cannot assign parameter before Module.__init__() call")
+
+    if not isinstance(name, torch._six.string_classes):
+        raise TypeError("parameter name should be a string. "
+                        "Got {}".format(torch.typename(name)))
+    if '.' in name:
+        raise KeyError("parameter name can't contain \".\"")
+    if name == '':
+        raise KeyError("parameter name can't be empty string \"\"")
+    if hasattr(self, name) and name not in self._parameters:
+        raise KeyError("attribute '{}' already exists".format(name))
+
+    if param is None:
+        self._parameters[name] = None
+    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
+        raise TypeError("cannot assign '{}' object to parameter '{}' "
+                        "(torch.nn.Parameter or ColoParameter or None required)"
+                        .format(torch.typename(param), name))
+    elif param.grad_fn:
+        raise ValueError(
+            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
+            "parameters must be created explicitly. To express '{0}' "
+            "as a function of another Tensor, compute the value in "
+            "the forward() method.".format(name))
+    else:
+        self._parameters[name] = param
+
+
+# Adapted from torch.nn.module.Module.__setattr__
+def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):
+    def remove_from(*dicts_or_sets):
+        for d in dicts_or_sets:
+            if name in d:
+                if isinstance(d, dict):
+                    del d[name]
+                else:
+                    d.discard(name)
+
+    params = self.__dict__.get('_parameters')
+    if isinstance(value, (ColoTensor, torch.nn.Parameter)):
+        if params is None:
+            raise AttributeError(
+                "cannot assign parameters before Module.__init__() call")
+        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
+        self.register_parameter(name, value)
+    elif params is not None and name in params:
+        if value is not None:
+            raise TypeError("cannot assign '{}' as parameter '{}' "
+                            "(torch.nn.Parameter or None expected)"
+                            .format(torch.typename(value), name))
+        self.register_parameter(name, value)
+    else:
+        modules = self.__dict__.get('_modules')
+        if isinstance(value, torch.nn.Module):
+            if modules is None:
+                raise AttributeError(
+                    "cannot assign module before Module.__init__() call")
+            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
+            modules[name] = value
+        elif modules is not None and name in modules:
+            if value is not None:
+                raise TypeError("cannot assign '{}' as child module '{}' "
+                                "(torch.nn.Module or None expected)"
+                                .format(torch.typename(value), name))
+            modules[name] = value
+        else:
+            buffers = self.__dict__.get('_buffers')
+            if buffers is not None and name in buffers:
+                if value is not None and not isinstance(value, torch.Tensor):
+                    raise TypeError("cannot assign '{}' as buffer '{}' "
+                                    "(torch.Tensor or None expected)"
+                                    .format(torch.typename(value), name))
+                buffers[name] = value
+            else:
+                object.__setattr__(self, name, value)
+
+
 def ColoModulize(module):
     """
@@ -64,7 +142,6 @@ def ColoModulize(module):
     module.colo_named_parameters = funcType(colo_named_parameters, module)
     module._colo_visited = True
-
 
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
 
     def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
@@ -77,11 +154,16 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device
 
+        # TODO(jzy) replace it with old __setattr__ in the exit() of context?
+        torch.nn.Module.__setattr__ = _setattr_with_colotensor
+        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
+
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
        The function to call at the end of the constructor of each module.
        FIXME(fjr) The module may be passed to this function multiple times?
        """
 
         if hasattr(module, '_colo_visited'):
             return
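The two assignments above patch torch.nn.Module globally as soon as the context object is constructed. A small illustrative sketch of the intended effect follows; the layer below is made up for the example, and save_payload is assumed to have a default in init_from_torch_tensor.

    import torch
    import torch.nn as nn
    from colossalai.tensor import ColoParameter

    class TinyLayer(nn.Module):
        def __init__(self):
            super().__init__()
            # With the patched __setattr__, a ColoParameter assigned as a plain
            # attribute is routed through register_parameter and stored in
            # self._parameters, exactly as an nn.Parameter would be.
            self.weight = ColoParameter.init_from_torch_tensor(tensor=torch.zeros(4, 4))

    layer = TinyLayer()                   # must be constructed while the patches are active
    assert 'weight' in layer._parameters  # registered as a parameter, not a plain attribute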
@@ -100,7 +182,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             tensor_detached = param.to(self._device).detach()
             tensor_detached.requires_grad = requires_grad
 
-            setattr(module, name,
-                    ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload))
+            colo_param = ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)
+            setattr(module, name, colo_param)
 
         ColoModulize(module)
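On the TODO(jzy) left in __init__ above: one possible shape for restoring the original methods when the context is torn down. The helper names and the exact hook to place them in are assumptions, since the exit path of InsertPostInitMethodToModuleSubClasses is not part of this diff.

    import torch

    # Sketch only: remember the original methods so the patch can be undone,
    # e.g. from an __exit__-style hook of the context.
    _orig_setattr = torch.nn.Module.__setattr__
    _orig_register_parameter = torch.nn.Module.register_parameter

    def _install_colo_patches():
        # _setattr_with_colotensor / _register_parameter_with_colotensor come from this file.
        torch.nn.Module.__setattr__ = _setattr_with_colotensor
        torch.nn.Module.register_parameter = _register_parameter_with_colotensor

    def _restore_module_methods():
        torch.nn.Module.__setattr__ = _orig_setattr
        torch.nn.Module.register_parameter = _orig_register_parameter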