[Tensor] add module handler for linear (#1021)

* add module spec for linear

* polish

* polish

* polish
Ziyue Jiang authored on 2022-05-26 11:50:44 +08:00, committed by GitHub
parent ee50497db2
commit 32291dd73f
7 changed files with 341 additions and 2 deletions

@@ -1,6 +1,7 @@
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter
+from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
+    ColoLinear
 import types
 from torch import nn
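
The new imports pull in the module-handler machinery: register_colo_module associates a torch.nn.Module subclass with a handler such as ColoLinear, and init_colo_module presumably consults that association when it later applies a parallel spec. As an illustration only (not the colossalai.tensor implementation), the registry boils down to a dict keyed by module class; register_handler and get_handler below are hypothetical names:

    # Illustration of the registry pattern implied by register_colo_module;
    # register_handler/get_handler are hypothetical, not colossalai APIs.
    from typing import Any, Dict, Type
    import torch

    _COLO_MODULE_REGISTRY: Dict[Type[torch.nn.Module], Any] = {}

    def register_handler(module_type: Type[torch.nn.Module], handler: Any) -> None:
        # Associate a module class (e.g. torch.nn.Linear) with the object that
        # describes how its parameters can be sharded.
        _COLO_MODULE_REGISTRY[module_type] = handler

    def get_handler(module: torch.nn.Module) -> Any:
        # Return the handler registered for this module's class, or None.
        return _COLO_MODULE_REGISTRY.get(type(module))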
@@ -101,6 +102,17 @@ def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.n
     else:
         object.__setattr__(self, name, value)
 
+def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]:
+    module_path, _, param_name = target.rpartition(".")
+    mod: torch.nn.Module = self.get_submodule(module_path)
+    if not hasattr(mod, param_name):
+        raise AttributeError(mod._get_name() + " has no attribute `"
+                             + param_name + "`")
+    param = getattr(mod, param_name)
+    return param
+
 def ColoModulize(module):
     """
@@ -124,6 +136,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         torch.nn.Module.__setattr__ = _setattr_with_colotensor
         torch.nn.Module.register_parameter = _register_parameter_with_colotensor
+        torch.nn.Module.get_parameter = _get_parameter_with_colotensor
+        register_colo_module(torch.nn.Linear, ColoLinear())
 
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """