[fx] added testing for all bert variants (#1207)

* [fx] added testing for all bert variants

* polish code
This commit is contained in:
Frank Lee
2022-07-06 10:50:49 +08:00
committed by GitHub
parent b5f25eb32a
commit 426a279ce7
2 changed files with 88 additions and 20 deletions

View File

@@ -1,6 +1,8 @@
import operator
import torch
from torch.fx.proxy import Proxy, Attribute
from typing import List, Union
from torch.utils._pytree import PyTree
__all__ = ['ColoProxy']
@@ -26,8 +28,12 @@ class ColoProxy(Proxy):
return self._meta_tensor
@meta_tensor.setter
def meta_tensor(self, tensor: torch.Tensor):
assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
def meta_tensor(self, tensor: Union[List[torch.Tensor], torch.Tensor]):
def _is_meta(item):
assert torch.is_tensor(item) and item.is_meta
torch.fx.node.map_aggregate(tensor, _is_meta)
self._meta_tensor = tensor
@property
@@ -83,6 +89,14 @@ class ColoProxy(Proxy):
def __setitem__(self, indices, values):
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
def __contains__(self, key):
if self.node.op == "placeholder":
# this is used to handle like
# if x in kwargs
# we don't handle this case for now
return False
return super().__contains__(key)
class ColoAttribute(ColoProxy):