[Tensor] add module check and bert test (#1031)

* add Embedding

* Add bert test

* polish

* add check module test

* polish

* polish

* polish

* polish
This commit is contained in:
Ziyue Jiang
2022-05-26 18:15:42 +08:00
committed by GitHub
parent 7106bd671d
commit 6c5996a56e
10 changed files with 170 additions and 45 deletions

View File

@@ -1,7 +1,7 @@
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
ColoLinear
ColoLinear, ColoEmbedding
import types
from torch import nn
@@ -137,7 +137,12 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
torch.nn.Module.__setattr__ = _setattr_with_colotensor
torch.nn.Module.register_parameter = _register_parameter_with_colotensor
torch.nn.Module.get_parameter = _get_parameter_with_colotensor
self._register_colo_modules()
def _register_colo_modules(self):
register_colo_module(torch.nn.Linear, ColoLinear())
register_colo_module(torch.nn.Embedding, ColoEmbedding())
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
@@ -179,5 +184,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
replaced_tensors[param] = colo_param
delattr(submodule, param_name)
setattr(submodule, param_name, colo_param)
colo_param.shared_param_modules.append(submodule)
ColoModulize(module)