[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.

* [autoparallel] add solver rotor.

* [autoparallel] add ckpt solvers.

* [autoparallel] modify codegen.

* [fx] fix annotation in test.

* [fx] remove check.

* [autoparallel] polish docstring.

* [fx] refactor MetaTensor.
Author: Super Daniel
Date: 2022-11-01 10:43:15 +08:00
Committed by: GitHub
Parent: 2b859502d5
Commit: 1e88811c7a

16 changed files with 1025 additions and 119 deletions


@@ -13,10 +13,10 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     """Calculate activation size of a node.
     Args:
-        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
+        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`.
     Returns:
-        int: The activation size
+        int: The activation size, unit is byte.
     """
     act_size = 0
     if isinstance(out, torch.Tensor):
@@ -38,10 +38,10 @@ def parameter_size(mod: torch.nn.Module) -> int:
     """Calculate parameter size of a node.
     Args:
-        mod (torch.nn.Module): The target `torch.nn.Module`
+        mod (torch.nn.Module): The target `torch.nn.Module`.
     Returns:
-        int: The parameter size
+        int: The parameter size, unit is byte.
     """
     param_size = 0
     for param in mod.parameters():
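
The docstring tweaks above make the unit explicit: both helpers return bytes. A minimal sketch of the torch.Tensor arm of activation_size under that convention (the recursive dict/list/tuple handling is omitted, and the numbers are illustrative):

import torch

def tensor_byte_size(t: torch.Tensor) -> int:
    # size in bytes = element count * bytes per element of the dtype
    return t.numel() * t.element_size()

x = torch.empty(4, 1024, dtype=torch.float16, device='meta')
assert tensor_byte_size(x) == 4 * 1024 * 2    # fp16 occupies 2 bytes per element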


@@ -232,12 +232,12 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     def pack(x):
         global cache, do_not_cache
-        if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
+        if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             x._node.meta['saved_tensor'] += [tensor]
             if not do_not_cache:
-                cache.add(x._tensor.uuid)
+                cache.add(x._tensor.data_ptr())
         return x

     def unpack(x):
@@ -270,7 +270,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     def extract_tensor(x: Any):
         if isinstance(x, MetaTensor):
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             return tensor
         if not isinstance(x, torch.finfo):
             return x
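
The pack hook above is now keyed on data_ptr() rather than a bolted-on uuid attribute, so aliased tensors that share a storage are cached once. A stripped-down sketch of the same dedup pattern with PyTorch's saved-tensors hooks (the FlopTensor and node-meta plumbing is omitted; names here are illustrative):

import torch

cache = set()    # storage keys already recorded
saved = []       # one entry per distinct saved storage

def pack(x):
    if x.data_ptr() not in cache:    # aliases share a storage, hence a data_ptr
        cache.add(x.data_ptr())
        saved.append(x.detach())
    return x

def unpack(x):
    return x

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    a = torch.randn(8, requires_grad=True)
    b = torch.tanh(a)    # tanh saves its output for the backward pass
    b.sum().backward()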


@@ -87,8 +87,8 @@ def calculate_fwd_out(n: Node) -> int:
     fwd_in = dict()
     for u in n.users:
-        fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
-    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+        fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)})
+    fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)}
     return activation_size(intersect(fwd_in, fwd_out))
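
Dropping the hasattr guard works because every torch.Tensor exposes data_ptr(), whereas the uuid attribute only existed on tensors that had passed through set_uuid. For reference, a plausible sketch of the intersect helper the last line relies on (an assumption about its shape, not the verbatim implementation):

from typing import Dict

import torch

def intersect(a: Dict, b: Dict) -> Dict:
    # keep only the entries whose storage key appears in both maps
    return {k: v for k, v in a.items() if k in b}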


@@ -12,10 +12,11 @@ from .constants import ALIAS_ATEN
 __all__ = ['MetaTensor']

-def set_uuid(x):
+def set_data_ptr(x):
     if isinstance(x, torch.Tensor):
-        if not hasattr(x, 'uuid'):
-            setattr(x, 'uuid', uuid.uuid4())
+        if not x.data_ptr():
+            data_ptr = uuid.uuid4()
+            x.data_ptr = lambda: data_ptr

 @compatibility(is_backward_compatible=False)
@@ -53,7 +54,7 @@ class MetaTensor(torch.Tensor):
         if not r._tensor.is_meta:
             r._tensor = r._tensor.to(torch.device('meta'))
         # only tensor not on `meta` should be copied to `meta`
-        set_uuid(r._tensor)
+        set_data_ptr(r._tensor)
         return r

     def __repr__(self):
@@ -88,7 +89,7 @@ class MetaTensor(torch.Tensor):
         # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy
         # of the input
         if func in ALIAS_ATEN:
-            setattr(out, 'uuid', args[0].uuid)
+            out.data_ptr = args[0].data_ptr
         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
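
All tensors handled here live on the meta device, where the diff assumes data_ptr() reports 0; binding a uuid-backed lambda then gives each storage a stable fake address, and the ALIAS_ATEN branch copies that callable so a view answers with the same address as its base. A self-contained sketch of the trick, mirroring the diff outside the MetaTensor wrapper:

import uuid

import torch

def set_data_ptr(x):
    if isinstance(x, torch.Tensor):
        if not x.data_ptr():                 # meta tensors are assumed to report a null address
            data_ptr = uuid.uuid4()
            x.data_ptr = lambda: data_ptr    # fake, but stable and unique

t = torch.empty(2, 2, device='meta')
set_data_ptr(t)
alias = t.view(4)
alias.data_ptr = t.data_ptr                  # what the ALIAS_ATEN branch does
assert alias.data_ptr() == t.data_ptr()      # the view shares its base's "address"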