mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[fx] provide a stable but not accurate enough version of profiler. (#1547)
* [fx] compute memory stat and flop count for MetaInfoProp. * [fx] modify node attribute. * [fx] modify ckpt_chen. * [fx] fix compatibility. * [fx] fix import error. * [fx] skip test for MetaInfoProp. * [fx] skip test for MetaInfoProp. * [fx] skip test for MetaInfoProp. * [fx] skip test for MetaInfoProp. * [fx] skip if torch 1.11.0. * [fx] recover MetaInfoProp support for PyTorch 1.11. * [fx] provide a stable but not accurate enough version of profiler. * [fx] provide a stable but not accurate enough version of profiler. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix import error.
This commit is contained in:
58
colossalai/fx/profiler/tensor.py
Normal file
58
colossalai/fx/profiler/tensor.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map, tree_flatten
|
||||
|
||||
__all__ = ['MetaTensor']
|
||||
|
||||
|
||||
class MetaTensor(torch.Tensor):
|
||||
"""
|
||||
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
|
||||
"""
|
||||
|
||||
_tensor: torch.Tensor
|
||||
|
||||
__slots__ = ['_tensor']
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
# The wrapping tensor (MetaTensor) shouldn't hold any
|
||||
# memory for the class in question, but it should still
|
||||
# advertise the same device as before
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
elem.size(),
|
||||
strides=elem.stride(),
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device='cpu',
|
||||
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
|
||||
r._tensor = elem
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
if self.grad_fn:
|
||||
return f"MetaTensor({self._tensor}, grad_fn={self.grad_fn})"
|
||||
return f"MetaTensor({self._tensor})"
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
|
||||
def unwrap(x):
|
||||
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
|
||||
x = MetaTensor(x)
|
||||
return x._tensor.to('meta') if isinstance(x, MetaTensor) else x
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
# run aten for backend=CPU but actually on backend=Meta
|
||||
out = func(*args, **kwargs)
|
||||
|
||||
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
|
||||
# our custom tensor subclass
|
||||
def wrap(x):
|
||||
return MetaTensor(x) if isinstance(x, torch.Tensor) else x
|
||||
|
||||
return tree_map(wrap, out)
|
Reference in New Issue
Block a user