mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[Tensor] add module handler for linear (#1021)
* add module spec for linear * polish * polish * polish
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from .utils import InsertPostInitMethodToModuleSubClasses
|
||||
import torch
|
||||
from colossalai.tensor import ColoTensor, ColoParameter
|
||||
from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
|
||||
ColoLinear
|
||||
import types
|
||||
|
||||
from torch import nn
|
||||
@@ -101,6 +102,17 @@ def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.n
|
||||
else:
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]:
    """Resolve the dotted path ``target`` to a parameter on this module tree.

    Unlike the stock ``nn.Module.get_parameter``, this variant does not insist
    that the resolved attribute be an ``nn.Parameter`` — it is returned as-is,
    which lets a ``ColoTensor`` stored in the parameter slot pass through.

    Raises:
        AttributeError: if the owning submodule has no attribute of that name.
    """
    # Split "a.b.weight" into the submodule path "a.b" and the leaf name "weight".
    owner_path, _, attr_name = target.rpartition(".")
    owner: torch.nn.Module = self.get_submodule(owner_path)

    if hasattr(owner, attr_name):
        return getattr(owner, attr_name)

    raise AttributeError(owner._get_name() + " has no attribute `"
                         + attr_name + "`")
|
||||
|
||||
def ColoModulize(module):
|
||||
"""
|
||||
@@ -124,6 +136,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
|
||||
torch.nn.Module.__setattr__ = _setattr_with_colotensor
|
||||
torch.nn.Module.register_parameter = _register_parameter_with_colotensor
|
||||
torch.nn.Module.get_parameter = _get_parameter_with_colotensor
|
||||
register_colo_module(torch.nn.Linear, ColoLinear())
|
||||
|
||||
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user