mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[fx] supported data-dependent control flow in model tracing (#1185)
* [fx] supported data-dependent control flow in model tracing * polish code
This commit is contained in:
@@ -37,6 +37,12 @@ class ColoProxy(Proxy):
|
||||
def _assert_has_meta(self):
|
||||
assert self.has_meta_tensor, f'Meta tensor is not set for {self.node.name}'
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Hack so we can track when devices are used. During meta-tensor propagation,
|
||||
# replace these values with a constant 'meta'
|
||||
return MetaDeviceAttribute(self, "device")
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
self._assert_has_meta()
|
||||
@@ -72,3 +78,27 @@ class ColoProxy(Proxy):
|
||||
|
||||
def __setitem__(self, indices, values):
|
||||
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
|
||||
|
||||
|
||||
class ColoAttribute(ColoProxy):
|
||||
|
||||
def __init__(self, root, attr: str):
|
||||
# this class is copied from torch.fx.Attribute
|
||||
# but inherits ColoProxy
|
||||
self.root = root
|
||||
self.attr = attr
|
||||
self.tracer = root.tracer
|
||||
self._node = None
|
||||
|
||||
@property
|
||||
def node(self):
|
||||
if self._node is None:
|
||||
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
|
||||
return self._node
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
|
||||
|
||||
|
||||
class MetaDeviceAttribute(ColoAttribute):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user