mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 13:59:08 +00:00
[fx] refactor code for profiler / enable fake tensor movement. (#1646)
* [fx/profiling] provide summary for MetaInfoProp.
* [fx/profiler] provide a table of summary.
* [fx/profiler] provide a table of summary.
* [fx/profiler] provide a table of summary.
* [fx/profiler] provide a table of summary.
* [fx] optimize table repr.
* [fx] optimize table repr.
* [fx] refactor code for profiler.
* [fx] add docstring.
* [fx] add docstring.
* [fx] skip test.
* [fx] redo.
* [fx] redo.
* [fx] fix import error for torch11.
* [fx] fix import error for torch11.
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Union, overload
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map, tree_flatten
|
||||
from torch.types import _bool, _dtype, _device
|
||||
from functools import singledispatchmethod
|
||||
|
||||
__all__ = ['MetaTensor']
|
||||
|
||||
@@ -16,6 +20,11 @@ class MetaTensor(torch.Tensor):
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, fake_device=None):
|
||||
# Avoid multiple wrapping
|
||||
if isinstance(elem, MetaTensor):
|
||||
fake_device = elem.device if fake_device is None else fake_device
|
||||
elem = elem._tensor
|
||||
|
||||
# The wrapping tensor (MetaTensor) shouldn't hold any
|
||||
# memory for the class in question, but it should still
|
||||
# advertise the same device as before
|
||||
@@ -74,3 +83,32 @@ class MetaTensor(torch.Tensor):
|
||||
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
|
||||
|
||||
return tree_map(wrap, out)
|
||||
|
||||
@singledispatchmethod
def to(self, *args, **kwargs) -> torch.Tensor:
    """Dispatch-aware extension of ``torch.Tensor.to()`` for ``MetaTensor``.

    This base overload catches every call whose first positional argument is
    not claimed by a registered overload (``str`` or ``torch.device``) and
    simply defers to ``torch.Tensor.to()``.

    Returns:
        result (MetaTensor): MetaTensor

    Usage:
        >>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
        >>> tensor.to(torch.uint8)
        MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100')
        >>> tensor.to(torch.device('cuda:42'))
        MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42')
        >>> tensor.to('vulkan')
        MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
    """
    # Mimics the C++ overload set of `Tensor::to` in @overload fashion:
    # singledispatchmethod routes on the type of the first positional argument.
    return super().to(*args, **kwargs)
|
||||
|
||||
@to.register
def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
    """Overload of ``to()`` taking the target device as a string (e.g. ``'cuda:0'``).

    Any dtype conversion is delegated to ``torch.Tensor.to()``; the device
    argument is not materialized — it is only recorded as ``fake_device`` on
    the returned ``MetaTensor``.
    """
    if dtype is None:
        converted = self
    else:
        converted = super().to(dtype, non_blocking, copy)
    # Re-wrap a deep copy so the result does not alias `self`.
    return MetaTensor(deepcopy(converted), fake_device=device)
|
||||
|
||||
@to.register
def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
    """Overload of ``to()`` taking the target device as a ``torch.device``.

    Any dtype conversion is delegated to ``torch.Tensor.to()``; the device
    argument is not materialized — it is only recorded as ``fake_device`` on
    the returned ``MetaTensor``.
    """
    if dtype is None:
        converted = self
    else:
        converted = super().to(dtype, non_blocking, copy)
    # Re-wrap a deep copy so the result does not alias `self`.
    return MetaTensor(deepcopy(converted), fake_device=device)
|
||||
|
Reference in New Issue
Block a user