mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)
* [autoparallel] first move. * [autoparallel] add solver rotor. * [autoparallel] add ckpt solvers. * [autoparallel] modify codegen. * [fx] fix annotation in test. * [fx] remove check. * [autoparallel] polish docstring. * [fx] refactor MetaTensor.
This commit is contained in:
@@ -232,12 +232,12 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G

     def pack(x):
         global cache, do_not_cache
-        if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
+        if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             x._node.meta['saved_tensor'] += [tensor]
             if not do_not_cache:
-                cache.add(x._tensor.uuid)
+                cache.add(x._tensor.data_ptr())
         return x

     def unpack(x):
@@ -270,7 +270,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G

     def extract_tensor(x: Any):
         if isinstance(x, MetaTensor):
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             return tensor
         if not isinstance(x, torch.finfo):
             return x
Reference in New Issue
Block a user