[fx/profiler] tuned the calculation of memory estimation (#1619)

* [fx] tuned the meta info and rotor solver.

* [fx] remove import.

* [fx] remove import.

* [fx] remove import.

* [fx] tune the meta calculations.

* [fx] polish comments.

* [fx] remove assertions.

* [fx] modify test cases.

* [fx] modify test cases.

* [fx] optimize import.

* [fx
This commit is contained in:
Super Daniel
2022-09-23 10:59:47 +08:00
committed by GitHub
parent f7f2248771
commit d967779a32
16 changed files with 413 additions and 207 deletions

View File

@@ -1,10 +1,21 @@
from colossalai.fx.profiler.memory import activation_size
import torch
from torch.fx import Node, Graph
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map
def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
"""Trace forward and backward graph with MetaTensor
Args:
@@ -33,7 +44,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
__slots__ = ['_tensor', '_node']
@staticmethod
def __new__(cls, tensor, placeholder=False, name=None):
def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
r = torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
@@ -41,7 +52,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
storage_offset=tensor.storage_offset(),
dtype=tensor.dtype,
layout=tensor.layout,
device='cpu',
device=fake_device if fake_device is not None else tensor.device,
requires_grad=tensor.requires_grad) # deceive the frontend for aten selections
r._tensor = tensor
if placeholder:
@@ -51,15 +62,23 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
'placeholder', (graph._root,),
name=namespace.create_name(name, tensor))
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
r._tensor = r._tensor.to(torch.device('meta'))
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
x = MetaProxy(x)
return x._tensor.to('meta') if isinstance(x, MetaProxy) else x
nonlocal fake_device
if isinstance(x, MetaProxy):
fake_device = x.device
x = x._tensor
# assert not isinstance(x, MetaProxy)
elif isinstance(x, torch.Tensor):
fake_device = x.device
x = x.to(torch.device('meta'))
return x
def get_node(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
@@ -70,6 +89,10 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
kwargs_node = tree_map(get_node, kwargs)
node = graph.create_node('call_function', func, args_node, kwargs_node)
if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
@@ -79,7 +102,12 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
x = x.to(torch.device('meta'))
return MetaProxy(
x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
def set_node(x):
x._node = node
@@ -90,10 +118,18 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
return out
def wrap(x):
return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x
return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
module(*args, **kwargs).sum().backward()
out = module(*args, **kwargs)
for tensor in normalize_tuple(out):
if is_autogradable(tensor) and tensor.requires_grad:
grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
torch.autograd.backward(tensor,
MetaProxy(grad, fake_device=tensor.device, placeholder=True),
retain_graph=True)
return graph