[fx] refactor code for profiler / enable fake tensor movement. (#1646)

* [fx/profiling] provide summary for MetaInfoProp.

* [fx/profiler] provide a table of summary.

* [fx] optimize table repr.

* [fx] refactor code for profiler.

* [fx] add docstring.

* [fx] skip test.

* [fx] redo.

* [fx] fix import error for torch11.
Author: Super Daniel
Date: 2022-09-27 10:26:52 +08:00
Committed by: GitHub
Parent: 5d0fdb9cb4
Commit: 6135e178b3
5 changed files with 103 additions and 67 deletions

@@ -1,5 +1,9 @@
from copy import deepcopy
from typing import Optional, Union, overload
import torch
from torch.utils._pytree import tree_map, tree_flatten
from torch.types import _bool, _dtype, _device
from functools import singledispatchmethod

__all__ = ['MetaTensor']
@@ -16,6 +20,11 @@ class MetaTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, fake_device=None):
        # Avoid multiple wrapping
        if isinstance(elem, MetaTensor):
            fake_device = elem.device if fake_device is None else fake_device
            elem = elem._tensor

        # The wrapping tensor (MetaTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
@@ -74,3 +83,32 @@ class MetaTensor(torch.Tensor):
            return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, out)

    @singledispatchmethod
    def to(self, *args, **kwargs) -> torch.Tensor:
        """An extension of `torch.Tensor.to()` to MetaTensor

        Returns:
            result (MetaTensor): MetaTensor

        Usage:
            >>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
            >>> tensor.to(torch.uint8)
            MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100')
            >>> tensor.to(torch.device('cuda:42'))
            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42')
            >>> tensor.to('vulkan')
            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
        """
        # this imitates the C++ overloads of `Tensor.to()` in the way of @overload
        return super().to(*args, **kwargs)

    @to.register
    def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
        return MetaTensor(deepcopy(result), fake_device=device)

    @to.register
    def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
        return MetaTensor(deepcopy(result), fake_device=device)