mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[Tensor] add module check and bert test (#1031)
* add Embedding * Add bert test * polish * add check module test * polish * polish * polish * polish
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from .utils import InsertPostInitMethodToModuleSubClasses
|
||||
import torch
|
||||
from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
|
||||
ColoLinear
|
||||
ColoLinear, ColoEmbedding
|
||||
import types
|
||||
|
||||
from torch import nn
|
||||
@@ -137,7 +137,12 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
torch.nn.Module.__setattr__ = _setattr_with_colotensor
|
||||
torch.nn.Module.register_parameter = _register_parameter_with_colotensor
|
||||
torch.nn.Module.get_parameter = _get_parameter_with_colotensor
|
||||
|
||||
self._register_colo_modules()
|
||||
|
||||
def _register_colo_modules(self):
|
||||
register_colo_module(torch.nn.Linear, ColoLinear())
|
||||
register_colo_module(torch.nn.Embedding, ColoEmbedding())
|
||||
|
||||
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||
"""
|
||||
@@ -179,5 +184,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
replaced_tensors[param] = colo_param
|
||||
delattr(submodule, param_name)
|
||||
setattr(submodule, param_name, colo_param)
|
||||
colo_param.shared_param_modules.append(submodule)
|
||||
|
||||
ColoModulize(module)
|
||||
|
Reference in New Issue
Block a user