[autoparallel] user-friendly API for CheckpointSolver. (#1879)
Merge for SC tutorial
@@ -22,12 +22,7 @@ __all__ = ['CheckpointSolverRotor']
 
 class CheckpointSolverRotor(CheckpointSolverBase):
 
-    def __init__(self,
-                 graph: Graph,
-                 memory_budget: float = -1,
-                 parameter_size: float = 0,
-                 cnode: List[str] = None,
-                 memory_slots: int = 500):
+    def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
         """This is the simple implementation of dynamic programming algorithm rotor
         in https://hal.inria.fr/hal-02352969. Some code are adapted from
         https://gitlab.inria.fr/hiepacs/rotor.
@@ -36,22 +31,22 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
             to the graph to retrieve all information needed, then we could use the following
             code to find a solution using `CheckpointSolverRotor`:
-            >>> solver = CheckpointSolverRotor(gm.graph, memory_budget=memory_budget, parameter_size=parameter_size)
+            >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
             >>> rotor_graph = solver.solve(force_python=True)    # otherwise use C solver
             >>> gm.graph = rotor_graph    # set the graph to a new graph
 
         Args:
             graph (Graph): The computing graph to be optimized.
-            memory_budget (float, optional): Memory constraint for the solution, unit is byte.
-            parameter_size (float, optional): The size of parameter of this model, unit is byte. Use `parameter_size(model)` to estimate.
+            free_memory (float, optional): Memory constraint for the solution, unit is byte.
+                Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
             cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
             memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
         """
-        super().__init__(graph, memory_budget, parameter_size, True, cnode)
+        super().__init__(graph, free_memory, True, cnode)
         self.memory_slots = memory_slots
 
         # construct chain
-        unit = self.memory_budget // self.memory_slots
+        unit = self.free_memory // self.memory_slots
         self.chain = self._construct_chain(self.graph, self.node_list)
         self.chain.discretize_all(unit)
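For readers following along, here is a minimal sketch of how the refactored constructor might be driven end to end, under the docstring's preconditions (an fx `GraphModule` whose graph has already been annotated by `MetaInfoProp`). The import path for `CheckpointSolverRotor` is an assumption and may differ between ColossalAI releases; `free_memory`, `solve(force_python=True)`, and `torch.cuda.mem_get_info` come from the diff above.

import torch
from torch.fx import GraphModule

# Assumed import path; the module layout may differ across ColossalAI releases.
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor


def apply_rotor_checkpointing(gm: GraphModule) -> GraphModule:
    """Annotate `gm` with activation checkpoints chosen by the rotor solver.

    Precondition: MetaInfoProp has already been run on `gm.graph`, so every
    node carries the memory/compute metadata the solver needs.
    """
    # Budget the solver with the bytes currently free on the device,
    # as the updated docstring suggests.
    free_memory, _total = torch.cuda.mem_get_info(device=0)

    solver = CheckpointSolverRotor(gm.graph, free_memory=free_memory)
    rotor_graph = solver.solve(force_python=True)    # otherwise use C solver

    gm.graph = rotor_graph    # swap in the solved graph
    gm.recompile()            # regenerate forward() from the new graph
    return gm

Note that `free_memory` replaces the old `memory_budget`/`parameter_size` pair: the caller now passes the bytes actually available, and the solver presumably accounts for the parameter footprint internally.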
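On the last lines of the hunk, the free-memory budget is split into `memory_slots` discrete units before the chain is discretized. A rough illustration of that arithmetic with made-up numbers (the exact rounding inside `discretize_all` may differ):

free_memory = 16 * 1024**3          # e.g. ~16 GiB reported by torch.cuda.mem_get_info
memory_slots = 500                  # default number of discretization slots

unit = free_memory // memory_slots  # one slot is about 34 MB here

# Activation sizes are then expressed as whole numbers of units, so the rotor
# dynamic program only has to index `memory_slots` distinct memory levels.
def to_slots(size_in_bytes: int) -> int:
    # Rounding up is one plausible convention; the real code may round differently.
    return -(-size_in_bytes // unit)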