[fx/profiler] tuned the calculation of memory estimation (#1619)

* [fx] tuned the meta info and rotor solver.

* [fx] remove import.

* [fx] remove import.

* [fx] remove import.

* [fx] tune the meta calculations.

* [fx] polish comments.

* [fx] remove assertions.

* [fx] modify test cases.

* [fx] modify test cases.

* [fx] optimize import.

* [fx
Author: Super Daniel
Date: 2022-09-23 10:59:47 +08:00
Committed by: GitHub
Parent: f7f2248771
Commit: d967779a32
16 changed files with 413 additions and 207 deletions

@@ -7,6 +7,7 @@ __all__ = ['MetaTensor']
 class MetaTensor(torch.Tensor):
     """
     A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
+    `fake_device` is the device that `MetaTensor` is supposed to run on.
     """

     _tensor: torch.Tensor
@@ -14,7 +15,7 @@ class MetaTensor(torch.Tensor):
     __slots__ = ['_tensor']

     @staticmethod
-    def __new__(cls, elem):
+    def __new__(cls, elem, fake_device=None):
         # The wrapping tensor (MetaTensor) shouldn't hold any
         # memory for the class in question, but it should still
         # advertise the same device as before
@@ -25,24 +26,37 @@ class MetaTensor(torch.Tensor):
             storage_offset=elem.storage_offset(),
             dtype=elem.dtype,
             layout=elem.layout,
-            device='cpu',
+            device=fake_device if fake_device is not None else elem.device,
             requires_grad=elem.requires_grad)    # deceive the frontend for aten selections
         r._tensor = elem
         # ...the real tensor is held as an element on the tensor.
+        if not r._tensor.is_meta:
+            r._tensor = r._tensor.to(torch.device('meta'))
+            # only tensor not on `meta` should be copied to `meta`
         return r

     def __repr__(self):
         if self.grad_fn:
-            return f"MetaTensor({self._tensor}, grad_fn={self.grad_fn})"
-        return f"MetaTensor({self._tensor})"
+            return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
+        return f"MetaTensor({self._tensor}, fake_device='{self.device}')"

     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        fake_device = None

         def unwrap(x):
-            if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                x = MetaTensor(x)
-            return x._tensor.to('meta') if isinstance(x, MetaTensor) else x
+            nonlocal fake_device
+            if isinstance(x, MetaTensor):
+                fake_device = x.device
+                x = x._tensor
+            elif isinstance(x, torch.Tensor):
+                fake_device = x.device
+                x = x.to(torch.device('meta'))
+            return x

+        if 'device' in kwargs:
+            fake_device = kwargs['device']
+            kwargs['device'] = torch.device('meta')

         args = tree_map(unwrap, args)
         kwargs = tree_map(unwrap, kwargs)
@@ -53,6 +67,10 @@ class MetaTensor(torch.Tensor):
         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
         def wrap(x):
-            return MetaTensor(x) if isinstance(x, torch.Tensor) else x
+            if isinstance(x, torch.Tensor):
+                nonlocal fake_device
+                if not x.is_meta:
+                    x = x.to(torch.device('meta'))
+            return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

         return tree_map(wrap, out)
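
Taken together, these hunks let a MetaTensor advertise a `fake_device` (captured in `unwrap()` from the inputs or from an explicit `device=` kwarg, and re-applied in `wrap()`) while every aten op actually runs on PyTorch's `meta` device. The sketch below is not part of the commit; the import path `colossalai.fx.profiler`, the tensor shapes, the `'cuda:0'` device string, and the assumption of a CUDA-enabled PyTorch build are illustrative only.

    # Minimal usage sketch (assumed import path; adjust to where MetaTensor is exported).
    import torch
    from colossalai.fx.profiler import MetaTensor   # assumption: re-exported here

    # Wrap real tensors but advertise a CUDA device. The wrapped payload is moved
    # to `meta`, so no device memory is allocated and no kernels are launched.
    x = MetaTensor(torch.empty(4, 8), fake_device='cuda:0')
    w = MetaTensor(torch.empty(8, 16), fake_device='cuda:0')

    # __torch_dispatch__ unwraps the inputs to `meta` tensors, runs the aten op
    # there (shape/dtype propagation only), and rewraps the output with the
    # captured fake device.
    y = x @ w
    print(y.shape)     # torch.Size([4, 16])
    print(y.device)    # reports the fake device ('cuda:0'), not 'meta'

The point of threading `fake_device` through `unwrap()`/`wrap()` is that downstream memory-estimation code sees the device the model would really run on, even though the tensors themselves never allocate real memory.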