[autockpt] considering parameter and optimizer weights. (#2279)

* [autockpt] make it work.

* [autockpt] linearize / merge shape-consistency nodes.

* [autockpt] considering parameter and optimizer weights.
This commit is contained in:
Super Daniel
2023-01-03 16:55:49 +08:00
committed by GitHub
parent b0d21d0c4f
commit 8e8900ff3f
3 changed files with 30 additions and 19 deletions

View File

@@ -35,10 +35,11 @@ class CheckpointSolverBase(ABC):
free_memory: float = -1.0,
requires_linearize: bool = False,
cnode: List[str] = None,
optim_multiplier: float = 1.0,
):
"""CheckpointSolver class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for
target computing graph.
"""``CheckpointSolverBase`` class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for target
computing graph.
Existing Solvers:
Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
@@ -49,9 +50,11 @@ class CheckpointSolverBase(ABC):
free_memory (float): Memory constraint for the solution.
requires_linearize (bool): Whether the graph needs to be linearized.
cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
optim_multiplier (float, optional): The multiplier of extra weight storage for the
``torch.optim.Optimizer``. Default to 1.0.
Warnings:
`MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
Meta information of the graph is required for any ``CheckpointSolver``.
"""
# super-dainiu: this graph is a temporary graph which can refer to
# the owning module, but we will return another deepcopy of it after
@@ -61,13 +64,14 @@ class CheckpointSolverBase(ABC):
_copy_output(graph, self.graph)
self.graph.set_codegen(ActivationCheckpointCodeGen())
# check if `MetaInfoProp` is done
# check if has meta information
if any(len(node.meta) == 0 for node in self.graph.nodes):
raise RuntimeError(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
"Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
)
self.free_memory = free_memory
self.parameter_size = _get_param_size(self.graph.owning_module)
# parameter memory = parameter size + optimizer extra weight storage
self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
self.cnode = cnode
self.requires_linearize = requires_linearize
if self.requires_linearize:
@@ -97,7 +101,7 @@ class CheckpointSolverBase(ABC):
the actual 'node' in linearized manner.
Remarks:
Do merge the inplace ops into the previous node.
Do merge the inplace ops and shape-consistency ops into the previous node.
"""
# Common nodes are type of nodes that could be seen as attributes and remain
@@ -136,7 +140,7 @@ class CheckpointSolverBase(ABC):
"""
def _is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
"""Get the inplace argument from ``torch.fx.Node``
"""
inplace = False
if n.op == "call_function":