[fx] supported model tracing for huggingface bert (#1201)

* [fx] supported model tracing for huggingface bert

* polish test
This commit is contained in:
Frank Lee
2022-07-05 13:19:57 +08:00
committed by GitHub
parent 060b917daf
commit f7878f465c
5 changed files with 126 additions and 4 deletions

View File

@@ -59,7 +59,11 @@ class ColoProxy(Proxy):
def size(self, dim: int = None):
self._assert_has_meta()
return self.meta_tensor.size(dim=dim)
if dim:
return self.meta_tensor.size(dim=dim)
else:
# size(dim=None) will trigger runtime error for meta tensor
return self.meta_tensor.size()
def __len__(self):
self._assert_has_meta()