[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.

* [autoparallel] add solver rotor.

* [autoparallel] add ckpt solvers.

* [autoparallel] modify codegen.

* [fx] fix annotation in test.

* [fx] remove check.

* [autoparallel] polish docstring.

* [fx] refactor MetaTensor.
Author: Super Daniel
Date: 2022-11-01 10:43:15 +08:00
Committed by: GitHub
Parent: 2b859502d5
Commit: 1e88811c7a

16 changed files with 1025 additions and 119 deletions

@@ -232,12 +232,12 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
 
     def pack(x):
         global cache, do_not_cache
-        if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
+        if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             x._node.meta['saved_tensor'] += [tensor]
             if not do_not_cache:
-                cache.add(x._tensor.uuid)
+                cache.add(x._tensor.data_ptr())
         return x
 
     def unpack(x):
@@ -270,7 +270,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     def extract_tensor(x: Any):
         if isinstance(x, MetaTensor):
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             return tensor
         if not isinstance(x, torch.finfo):
            return x
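
For context, the hunks above replace the custom uuid attribute with the tensor's data_ptr() as the cache key used to deduplicate activations recorded by the pack hook. Below is a minimal, self-contained sketch of that pattern using plain torch tensors and torch.autograd.graph.saved_tensors_hooks rather than the FlopTensor/MetaTensor wrappers in this repo; the cache and saved_tensors names here are illustrative, not the actual profiler internals.

import torch

cache = set()          # data_ptr()s of storages already recorded
saved_tensors = []     # stand-in for node.meta['saved_tensor']

def pack(x):
    # Record each saved activation only once, keyed by its data pointer.
    if isinstance(x, torch.Tensor) and x.data_ptr() not in cache:
        saved_tensors.append(x.detach())
        cache.add(x.data_ptr())
    return x

def unpack(x):
    # The hook only observes what autograd saves; nothing to undo.
    return x

a = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = (a * a).sum()
out.backward()

# 'a' is saved twice by mul's backward but recorded only once,
# because both saves share the same data_ptr().
print(len(saved_tensors), len(cache))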