mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[Tensor] add module check and bert test (#1031)
* add Embedding * Add bert test * polish * add check module test * polish * polish * polish * polish
This commit is contained in:
@@ -26,6 +26,13 @@ class ColoParameter(ColoTensor):
|
||||
self._type = TensorType.MODEL
|
||||
self._graph_node = None
|
||||
|
||||
# a list contains modules sharing this ColoParameter with others.
|
||||
self._shared_param_modules = []
|
||||
|
||||
@property
|
||||
def shared_param_modules(self):
|
||||
return self._shared_param_modules
|
||||
|
||||
@staticmethod
|
||||
def from_torch_tensor(tensor: torch.Tensor,
|
||||
requires_grad: bool = True,
|
||||
@@ -36,3 +43,4 @@ class ColoParameter(ColoTensor):
|
||||
|
||||
def __repr__(self):
|
||||
return f'ColoParameter: {torch.Tensor.__repr__(self)}'
|
||||
|
||||
|
Reference in New Issue
Block a user