[fx] added unit test for coloproxy (#1119)

* [fx] added unit test for coloproxy

* polish code

* polish code
This commit is contained in:
Frank Lee
2022-06-15 15:27:51 +08:00
committed by GitHub
parent 7d14b473f0
commit 16302a5359
2 changed files with 40 additions and 4 deletions

View File

@@ -19,16 +19,16 @@ class ColoProxy(Proxy):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.meta_tensor = None
self._meta_tensor = None
@property
def meta_tensor(self):
return self.meta_tensor
return self._meta_tensor
@meta_tensor.setter
def meta_tensor(self, tensor: torch.Tensor):
assert tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
self.meta_tensor = tensor
assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
self._meta_tensor = tensor
@property
def has_meta_tensor(self):
@@ -42,6 +42,19 @@ class ColoProxy(Proxy):
self._assert_has_meta()
return self.meta_tensor.dtype
@property
def shape(self):
self._assert_has_meta()
return self.meta_tensor.shape
def dim(self):
self._assert_has_meta()
return self.meta_tensor.dim()
def size(self, dim: int = None):
self._assert_has_meta()
return self.meta_tensor.size(dim=dim)
def __len__(self):
self._assert_has_meta()
return len(self.meta_tensor)