[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)
* [autoparallel] first move.
* [autoparallel] add solver rotor.
* [autoparallel] add ckpt solvers.
* [autoparallel] modify codegen.
* [fx] fix annotation in test.
* [fx] remove check.
* [autoparallel] polish docstring.
* [fx] refactor MetaTensor.
@@ -13,10 +13,10 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     """Calculate activation size of a node.
 
     Args:
-        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
+        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`.
 
     Returns:
-        int: The activation size
+        int: The activation size, unit is byte.
     """
     act_size = 0
     if isinstance(out, torch.Tensor):
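The docstring now states the unit explicitly: bytes. As a rough illustration only (not the actual body of activation_size), a byte count for a tensor or a nested container of tensors can be accumulated as numel() * element_size():

import torch
from typing import Any

def _bytes(out: Any) -> int:
    # Hypothetical helper, assuming the usual numel * element_size accounting.
    if isinstance(out, torch.Tensor):
        return out.numel() * out.element_size()
    if isinstance(out, dict):
        return sum(_bytes(v) for v in out.values())
    if isinstance(out, (list, tuple)):
        return sum(_bytes(v) for v in out)
    return 0

assert _bytes(torch.zeros(4, 4)) == 64                                         # 16 fp32 elements * 4 bytes
assert _bytes([torch.zeros(2), {"k": torch.zeros(2, dtype=torch.float16)}]) == 12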
@@ -38,10 +38,10 @@ def parameter_size(mod: torch.nn.Module) -> int:
     """Calculate parameter size of a node.
 
     Args:
-        mod (torch.nn.Module): The target `torch.nn.Module`
+        mod (torch.nn.Module): The target `torch.nn.Module`.
 
     Returns:
-        int: The parameter size
+        int: The parameter size, unit is byte.
     """
     param_size = 0
    for param in mod.parameters():
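Same unit convention for parameters. A quick sanity check of the byte count on a small module, assuming the usual sum of numel() * element_size() over mod.parameters() (a stand-in, not the library function itself):

import torch

def _param_bytes(mod: torch.nn.Module) -> int:
    # Hypothetical stand-in for parameter_size: total bytes across all parameters.
    return sum(p.numel() * p.element_size() for p in mod.parameters())

lin = torch.nn.Linear(8, 4)                  # weight 4x8 + bias 4, fp32
assert _param_bytes(lin) == (8 * 4 + 4) * 4  # 144 bytes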
@@ -232,12 +232,12 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
 
     def pack(x):
         global cache, do_not_cache
-        if isinstance(x, FlopTensor) and not x._tensor.uuid in cache:
+        if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             x._node.meta['saved_tensor'] += [tensor]
             if not do_not_cache:
-                cache.add(x._tensor.uuid)
+                cache.add(x._tensor.data_ptr())
         return x
 
     def unpack(x):
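pack now keys its cache on the tensor's (possibly faked) data_ptr() instead of the old uuid attribute, so a tensor saved by several ops is recorded only once. The hooks themselves plug into autograd's saved-tensor mechanism; a minimal sketch of that wiring, independent of the ColossalAI classes such as FlopTensor:

import torch

cache = set()
saved = []

def pack(x):
    # Record each tensor autograd saves for backward once, keyed by storage identity.
    if isinstance(x, torch.Tensor) and x.data_ptr() not in cache:
        saved.append(x.detach())
        cache.add(x.data_ptr())
    return x

def unpack(x):
    return x

a = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    (a * a).sum().backward()
print(len(saved))   # number of distinct tensors saved for the backward pass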
@@ -270,7 +270,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     def extract_tensor(x: Any):
         if isinstance(x, MetaTensor):
             tensor = x._tensor.detach()
-            tensor.uuid = x._tensor.uuid
+            tensor.data_ptr = x._tensor.data_ptr
             return tensor
         if not isinstance(x, torch.finfo):
             return x
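As in pack, the detached copy is given the original's bound data_ptr rather than a uuid attribute, so both report the same identity. The idiom in isolation (illustrative, not the library code; the 0xDEADBEEF value is a made-up placeholder):

import torch

x = torch.empty(4, device='meta')
x.data_ptr = lambda: 0xDEADBEEF      # pretend an identity was attached elsewhere
t = x.detach()
t.data_ptr = x.data_ptr              # carry the identity over to the copy
assert t.data_ptr() == x.data_ptr()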
@@ -87,8 +87,8 @@ def calculate_fwd_out(n: Node) -> int:
 
     fwd_in = dict()
     for u in n.users:
-        fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
-    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+        fwd_in.update({x.data_ptr(): x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor)})
+    fwd_out = {x.data_ptr(): x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor)}
     return activation_size(intersect(fwd_in, fwd_out))
 
 
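With both dicts keyed by data_ptr(), the forward-output size is the bytes of tensors that are both produced by this node and consumed by its users. A hedged sketch of that key intersection (the real intersect helper lives elsewhere in the profiler):

import torch
from typing import Dict

def intersect_sketch(a: Dict, b: Dict) -> Dict:
    # Keep only the tensors whose identity key appears on both sides.
    return {k: v for k, v in a.items() if k in b}

t, u = torch.zeros(2), torch.zeros(3)
fwd_in = {t.data_ptr(): t, u.data_ptr(): u}
fwd_out = {t.data_ptr(): t}
assert intersect_sketch(fwd_in, fwd_out).keys() == {t.data_ptr()}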
@@ -12,10 +12,11 @@ from .constants import ALIAS_ATEN
 __all__ = ['MetaTensor']
 
 
-def set_uuid(x):
+def set_data_ptr(x):
     if isinstance(x, torch.Tensor):
-        if not hasattr(x, 'uuid'):
-            setattr(x, 'uuid', uuid.uuid4())
+        if not x.data_ptr():
+            data_ptr = uuid.uuid4()
+            x.data_ptr = lambda: data_ptr
 
 
 @compatibility(is_backward_compatible=False)
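The reason a fake pointer is needed at all: the `if not x.data_ptr():` check above relies on tensors on the meta device reporting a null data pointer, since they have no real storage to identify, so a uuid4 stand-in is attached instead. A small demonstration of the premise and the patching trick, as an illustration rather than the class's code path:

import uuid
import torch

m = torch.empty(8, device='meta')
assert m.data_ptr() == 0             # meta tensors report a null data pointer

fake = uuid.uuid4()
m.data_ptr = lambda: fake            # instance attribute shadows the method
assert m.data_ptr() == fake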
@@ -53,7 +54,7 @@ class MetaTensor(torch.Tensor):
         if not r._tensor.is_meta:
             r._tensor = r._tensor.to(torch.device('meta'))
         # only tensor not on `meta` should be copied to `meta`
-        set_uuid(r._tensor)
+        set_data_ptr(r._tensor)
         return r
 
     def __repr__(self):
@@ -88,7 +89,7 @@ class MetaTensor(torch.Tensor):
         # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy
         # of the input
         if func in ALIAS_ATEN:
-            setattr(out, 'uuid', args[0].uuid)
+            out.data_ptr = args[0].data_ptr
 
         # Now, we want to continue propagating this tensor, so we rewrap Tensors in
         # our custom tensor subclass
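ALIAS_ATEN covers ops that return views of their input, so the output is given the input's data_ptr rather than a fresh one. That mirrors how real tensors behave: a view shares storage with its source, while a copy does not.

import torch

x = torch.randn(4, 4)
v = x.view(16)                        # aliasing op: no physical copy
assert v.data_ptr() == x.data_ptr()   # same storage, same pointer

y = x.clone()                         # non-aliasing op: new storage
assert y.data_ptr() != x.data_ptr()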