[Tensor] add module check and bert test (#1031)

* add Embedding

* Add bert test

* polish

* add check module test

* polish

* polish

* polish

* polish
This commit is contained in:
Ziyue Jiang
2022-05-26 18:15:42 +08:00
committed by GitHub
parent 7106bd671d
commit 6c5996a56e
10 changed files with 170 additions and 45 deletions

View File

@@ -26,6 +26,13 @@ class ColoParameter(ColoTensor):
self._type = TensorType.MODEL
self._graph_node = None
# a list contains modules sharing this ColoParameter with others.
self._shared_param_modules = []
@property
def shared_param_modules(self):
return self._shared_param_modules
@staticmethod
def from_torch_tensor(tensor: torch.Tensor,
requires_grad: bool = True,
@@ -36,3 +43,4 @@ class ColoParameter(ColoTensor):
def __repr__(self):
return f'ColoParameter: {torch.Tensor.__repr__(self)}'