mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[fx] added unit test for coloproxy (#1119)
* [fx] added unit test for coloproxy * polish code * polish code
This commit is contained in:
@@ -19,16 +19,16 @@ class ColoProxy(Proxy):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.meta_tensor = None
|
||||
self._meta_tensor = None
|
||||
|
||||
@property
|
||||
def meta_tensor(self):
|
||||
return self.meta_tensor
|
||||
return self._meta_tensor
|
||||
|
||||
@meta_tensor.setter
|
||||
def meta_tensor(self, tensor: torch.Tensor):
|
||||
assert tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
|
||||
self.meta_tensor = tensor
|
||||
assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
|
||||
self._meta_tensor = tensor
|
||||
|
||||
@property
|
||||
def has_meta_tensor(self):
|
||||
@@ -42,6 +42,19 @@ class ColoProxy(Proxy):
|
||||
self._assert_has_meta()
|
||||
return self.meta_tensor.dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
self._assert_has_meta()
|
||||
return self.meta_tensor.shape
|
||||
|
||||
def dim(self):
|
||||
self._assert_has_meta()
|
||||
return self.meta_tensor.dim()
|
||||
|
||||
def size(self, dim: int = None):
|
||||
self._assert_has_meta()
|
||||
return self.meta_tensor.size(dim=dim)
|
||||
|
||||
def __len__(self):
|
||||
self._assert_has_meta()
|
||||
return len(self.meta_tensor)
|
||||
|
Reference in New Issue
Block a user