[fx] supported data-dependent control flow in model tracing (#1185)

* [fx] supported data-dependent control flow in model tracing

* polish code
This commit is contained in:
Frank Lee
2022-06-29 15:05:25 +08:00
committed by GitHub
parent c463f8adf9
commit 6d86f1bc91
10 changed files with 461 additions and 0 deletions

View File

@@ -37,6 +37,12 @@ class ColoProxy(Proxy):
def _assert_has_meta(self):
assert self.has_meta_tensor, f'Meta tensor is not set for {self.node.name}'
@property
def device(self):
# Hack so we can track when devices are used. During meta-tensor propagation,
# replace these values with a constant 'meta'
return MetaDeviceAttribute(self, "device")
@property
def dtype(self):
self._assert_has_meta()
@@ -72,3 +78,27 @@ class ColoProxy(Proxy):
def __setitem__(self, indices, values):
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str):
# this class is copied from torch.fx.Attribute
# but inherits ColoProxy
self.root = root
self.attr = attr
self.tracer = root.tracer
self._node = None
@property
def node(self):
if self._node is None:
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
class MetaDeviceAttribute(ColoAttribute):
pass