Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-14 21:51:57 +00:00
[fx/profiler] tuned the calculation of memory estimation (#1619)
* [fx] tuned the meta info and rotor solver.
* [fx] remove import.
* [fx] remove import.
* [fx] remove import.
* [fx] tune the meta calculations.
* [fx] polish comments.
* [fx] remove assertions.
* [fx] modify test cases.
* [fx] modify test cases.
* [fx] optimize import.
* [fx
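For orientation before the diff: a minimal sketch of how the updated entry point could be called. The `fake_device` parameter and the forward-plus-backward tracing behavior come from the diff below; the import path is an assumption and may differ in your checkout.

import torch
import torch.nn as nn

# Assumed import path; adjust to wherever meta_trace is exported in your tree.
from colossalai.fx.profiler import meta_trace

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
x = torch.rand(4, 8, requires_grad=True)

# Pretend the model lives on a GPU while every op actually runs on meta
# tensors; the returned torch.fx.Graph records forward and backward aten calls.
graph = meta_trace(model, torch.device('cuda:0'), x)
graph.print_tabular()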
@@ -1,10 +1,21 @@
+from colossalai.fx.profiler.memory import activation_size
 import torch
 from torch.fx import Node, Graph
 from torch.fx.graph import _Namespace
 from torch.utils._pytree import tree_map
 
 
-def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
+def normalize_tuple(x):
+    if not isinstance(x, tuple):
+        return (x,)
+    return x
+
+
+def is_autogradable(x):
+    return isinstance(x, torch.Tensor) and x.is_floating_point()
+
+
+def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
     """Trace forward and backward graph with MetaTensor
 
     Args:
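The two new helpers are small enough to demo standalone. A sketch (definitions copied from this hunk): `normalize_tuple` lets single- and multi-output modules be handled uniformly, and `is_autogradable` filters out tensors autograd cannot differentiate.

import torch


def normalize_tuple(x):
    # wrap a single output so callers can always iterate
    if not isinstance(x, tuple):
        return (x,)
    return x


def is_autogradable(x):
    # only floating-point tensors can carry gradients
    return isinstance(x, torch.Tensor) and x.is_floating_point()


out = torch.rand(2, 2)
for t in normalize_tuple(out):    # a lone tensor becomes a 1-tuple
    print(is_autogradable(t))     # True
print(is_autogradable(torch.ones(2, dtype=torch.long)))    # False: integer tensor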
@@ -33,7 +44,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
         __slots__ = ['_tensor', '_node']
 
         @staticmethod
-        def __new__(cls, tensor, placeholder=False, name=None):
+        def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
             r = torch.Tensor._make_wrapper_subclass(
                 cls,
                 tensor.size(),
@@ -41,7 +52,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
                 storage_offset=tensor.storage_offset(),
                 dtype=tensor.dtype,
                 layout=tensor.layout,
-                device='cpu',
+                device=fake_device if fake_device is not None else tensor.device,
                 requires_grad=tensor.requires_grad)    # deceive the frontend for aten selections
             r._tensor = tensor
             if placeholder:
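What this hunk relies on: `torch.Tensor._make_wrapper_subclass` lets a wrapper report an arbitrary device to the frontend while owning no storage. A minimal standalone sketch (the class name is hypothetical, not part of the diff):

import torch


class FakeDeviceTensor(torch.Tensor):
    # hypothetical stand-in for MetaProxy, reduced to just the device trick

    @staticmethod
    def __new__(cls, tensor, fake_device):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            dtype=tensor.dtype,
            device=fake_device,    # what .device will report to the frontend
            requires_grad=tensor.requires_grad)


t = FakeDeviceTensor(torch.rand(2, 2), torch.device('cuda:0'))
print(t.device)    # cuda:0, even on a CPU-only machine: nothing was allocated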
@@ -51,15 +62,23 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
                     'placeholder', (graph._root,),
                     name=namespace.create_name(name, tensor))
             # ...the real tensor is held as an element on the tensor.
+            if not r._tensor.is_meta:
+                r._tensor = r._tensor.to(torch.device('meta'))
             return r
 
         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
 
             def unwrap(x):
-                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    x = MetaProxy(x)
-                return x._tensor.to('meta') if isinstance(x, MetaProxy) else x
+                nonlocal fake_device
+                if isinstance(x, MetaProxy):
+                    fake_device = x.device
+                    x = x._tensor
+                    # assert not isinstance(x, MetaProxy)
+                elif isinstance(x, torch.Tensor):
+                    fake_device = x.device
+                    x = x.to(torch.device('meta'))
+                return x
 
             def get_node(x):
                 if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
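The rewritten `unwrap` does two jobs per tensor: record the device the caller believes the tensor is on, and hand a pure meta tensor to the aten call. A standalone sketch of that pattern over a nested argument tree (`global` stands in for the diff's `nonlocal`, since there is no enclosing function here):

import torch
from torch.utils._pytree import tree_map

fake_device = None


def unwrap(x):
    global fake_device
    if isinstance(x, torch.Tensor):
        fake_device = x.device              # remember the "real" device
        x = x.to(torch.device('meta'))      # dispatch the op on meta instead
    return x


args = (torch.rand(2), {'w': torch.rand(3)})
meta_args = tree_map(unwrap, args)
print(meta_args[0].is_meta, meta_args[1]['w'].is_meta, fake_device)
# True True cpu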
@@ -70,6 +89,10 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             kwargs_node = tree_map(get_node, kwargs)
             node = graph.create_node('call_function', func, args_node, kwargs_node)
 
+            if 'device' in kwargs:
+                fake_device = kwargs['device']
+                kwargs['device'] = torch.device('meta')
+
             args = tree_map(unwrap, args)
             kwargs = tree_map(unwrap, kwargs)
 
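Why the new interception matters: factory-style ops (torch.empty, torch.ones, ...) take their target device as a kwarg, so a traced `device='cuda'` call would otherwise attempt a real GPU allocation. A sketch of the redirect:

import torch

kwargs = {'dtype': torch.float32, 'device': torch.device('cuda:0')}

fake_device = None
if 'device' in kwargs:
    fake_device = kwargs['device']              # remember what was requested...
    kwargs['device'] = torch.device('meta')     # ...but allocate nothing

t = torch.empty(2, 2, **kwargs)
print(t.device, '| requested:', fake_device)    # meta | requested: cuda:0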
@@ -79,7 +102,12 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             # Now, we want to continue propagating this tensor, so we rewrap Tensors in
             # our custom tensor subclass
             def wrap(x):
-                return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
+                if isinstance(x, torch.Tensor):
+                    nonlocal fake_device
+                    if not x.is_meta:
+                        x = x.to(torch.device('meta'))
+                return MetaProxy(
+                    x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
 
             def set_node(x):
                 x._node = node
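The extra guard in `wrap` normalizes any real tensor an op might hand back to meta before rewrapping, so downstream ops keep dispatching on meta. A reduced sketch without the MetaProxy rewrap:

import torch


def wrap(x):
    # normalize a real tensor to meta before it is wrapped and propagated
    if isinstance(x, torch.Tensor) and not x.is_meta:
        x = x.to(torch.device('meta'))
    return x


out = wrap(torch.rand(2, 2))    # stand-in for an op output that is not meta
print(out.is_meta)              # True: later ops keep running on meta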
@@ -90,10 +118,18 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
         return out
 
     def wrap(x):
-        return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x
+        return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x
 
     args = tree_map(wrap, args)
     kwargs = tree_map(wrap, kwargs)
 
-    module(*args, **kwargs).sum().backward()
+    out = module(*args, **kwargs)
+
+    for tensor in normalize_tuple(out):
+        if is_autogradable(tensor) and tensor.requires_grad:
+            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
+                tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
+            torch.autograd.backward(tensor,
+                                    MetaProxy(grad, fake_device=tensor.device, placeholder=True),
+                                    retain_graph=True)
     return graph
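The old `module(*args, **kwargs).sum().backward()` forced a scalar reduction and assumed a single tensor output; the new loop instead seeds each autogradable output with an explicit meta gradient. A stripped-down sketch of that pattern on plain meta tensors (no MetaProxy, so graph recording is omitted):

import torch

x = torch.rand(4, 4, device='meta', requires_grad=True)
out = x * 2    # stays on meta: no real computation or memory

# seed backward with an explicit meta gradient instead of .sum().backward()
grad = torch.empty_like(out, device=torch.device('meta'))
torch.autograd.backward(out, grad, retain_graph=True)

print(x.grad.shape, x.grad.is_meta)    # torch.Size([4, 4]) True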