[autoparallel] refactor and add rotorc. (#1789)

* [autoparallel] refactor and add rotorc.

* [autoparallel] refactor and add rotorc.
This commit is contained in:
Super Daniel
2022-11-03 12:32:51 +08:00
committed by GitHub
parent 4d6e1284cb
commit e8a9bebc87
5 changed files with 333 additions and 129 deletions

View File

@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
@@ -15,9 +15,9 @@ from colossalai.fx.profiler import (
from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
__all__ = ['CheckpointSolverBase']
__all__ = ['CheckpointSolverRotor']
class CheckpointSolverRotor(CheckpointSolverBase):
@@ -59,11 +59,12 @@ class CheckpointSolverRotor(CheckpointSolverBase):
self.back_ptr = None
self.sequence = None
def solve(self, force_python: bool = False) -> Graph:
def solve(self, force_python: bool = False, verbose: bool = False) -> Graph:
"""Solve the checkpointing problem using rotor algorithm.
Args:
force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
verbose (bool, optional): Print verbose information. Defaults to False.
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
@@ -76,14 +77,22 @@ class CheckpointSolverRotor(CheckpointSolverBase):
else:
self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)
if verbose:
self.print_chain()
# backtrack
try:
self.sequence = self._backtrack(chain, 0, chain.length, self.memory_slots, self.cost_table, self.back_ptr)
self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
self.back_ptr)
self._annotate_from_sequence(self.sequence, self.node_list)
except RuntimeError as e:
except ValueError as e:
# using logger to annonce that the solver is failed
logger = get_dist_logger()
logger.warning(f'Checkpoint solver failed: {e}')
raise ValueError
if verbose:
self.print_sequence()
return deepcopy(self.graph)
@@ -100,42 +109,42 @@ class CheckpointSolverRotor(CheckpointSolverBase):
@classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
input_tensors = cls._extract_input(graph)
fwd_time, bwd_time, ftmp, btmp = list(), list(), list(), list()
ftime, btime, ftmp, btmp = list(), list(), list(), list()
xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]
for idx, node in enumerate(node_list):
for node in node_list:
node_info = cls._extract_node_info(node)
fwd_time.append(node_info[0])
bwd_time.append(node_info[1])
ftime.append(node_info[0])
btime.append(node_info[1])
x.append(node_info[2])
xbar.append(node_info[3])
ftmp.append(node_info[4])
btmp.append(node_info[5])
# currently we view loss backward temp as zero
bwd_time.append(0)
btime.append(0)
btmp.append(0)
return Chain(fwd_time, bwd_time, x, xbar, ftmp, btmp)
return Chain(ftime, btime, x, xbar, ftmp, btmp)
@classmethod
def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
"""Extract node info from a list of nodes"""
xbar = 0
fwd_time = 0
bwd_time = 0
ftime = 0
btime = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
# minimum flop count is required
fwd_time += max(calculate_fwd_time(n), 1.0)
bwd_time += max(calculate_bwd_time(n), 1.0)
ftime += max(calculate_fwd_time(n), 1.0)
btime += max(calculate_bwd_time(n), 1.0)
x = calculate_fwd_out(node[-1])
xbar = max(x, xbar)
ftmp = cls._extract_ftmp(node)
btmp = cls._extract_btmp(node)
return fwd_time, bwd_time, x, xbar, ftmp, btmp
return ftime, btime, x, xbar, ftmp, btmp
@staticmethod
def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
@@ -180,17 +189,17 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return btmp
@staticmethod
def _compute_table(chain: Chain, mem_slots: int) -> Tuple:
def _compute_table(chain: Chain, mmax: int) -> Tuple:
"""Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
mem_slots (int): Number of slots for discretizing memory budget.
mmax (int): Maximum number of memory slots.
Returns:
cost_table (List[List[Dict[int, Tuple]]]): cost_table[m][lmin][lmax] with lmin = 0...chain.length
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
back_ptr (List[List[Dict[int, Tuple]]]): back_ptr[m][lmin][lmax] is (True,) if the optimal choice
cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint
of length j
"""
@@ -203,13 +212,13 @@ class CheckpointSolverRotor(CheckpointSolverBase):
btmp = chain.btmp + [0]
# Build table
cost_table = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)]
back_ptr = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)]
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for m in range(mem_slots + 1):
for i in range(chain.length + 1):
for m in range(mmax + 1):
for i in range(len(chain) + 1):
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
if m >= limit: # Equation (1)
cost_table[m][i][i] = ftime[i] + btime[i]
@@ -217,9 +226,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
cost_table[m][i][i] = float("inf")
# Compute everything
for m in range(mem_slots + 1):
for d in range(1, chain.length + 1):
for i in range(chain.length + 1 - d):
for m in range(mmax + 1):
for d in range(1, len(chain) + 1):
for i in range(len(chain) + 1 - d):
idx = i + d
mmin = x[idx + 1] + x[i + 1] + ftmp[i]
if idx > i + 1:
@@ -248,20 +257,46 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return cost_table, back_ptr
@staticmethod
def _compute_table_c(chain: Chain, mem_slots: int) -> Tuple:
raise NotImplementedError("C implementation not available yet")
def _compute_table_c(chain: Chain, mmax: int) -> Tuple:
try:
from .rotorc import compute_table
def _backtrack(self, chain: Chain, lmin: int, lmax: int, mem_budget: int, cost_table: List[List[Dict[int, Tuple]]],
back_ptr: List[List[Dict[int, int]]]) -> List[int]:
# build module if module not found
except ModuleNotFoundError:
import os
import subprocess
import sys
logger = get_dist_logger()
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
[
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
f"--build-lib={this_dir}"
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if result.wait() == 0:
logger.info("rotorc has been built!", ranks=[0])
from .rotorc import compute_table
else:
logger.warning("rotorc built failed! Using python version!", ranks=[0])
return CheckpointSolverRotor._compute_table(chain, mmax)
return compute_table(chain, mmax)
@staticmethod
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
back_ptr: List[Any]) -> "Sequence":
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
lmin (int): The left index of the interval to backtrack.
lmax (int): The right index of the interval to backtrack.
mem_budget (int): The memory budget for processing this interval.
cost_table (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions
back_ptr (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions
lhs (int): The left index of the interval to backtrack.
rhs (int): The right index of the interval to backtrack.
budget (int): The memory budget for processing this interval.
cost_table (List[Any]): See `._compute_table()` for definitions
back_ptr (List[Any]): See `._compute_table()` for definitions
Raises:
ValueError: Can not process the chain.
@@ -269,36 +304,45 @@ class CheckpointSolverRotor(CheckpointSolverBase):
Returns:
sequence (Sequence): The sequence of executing nodes with checkpoints.
"""
if mem_budget <= 0:
raise ValueError(f"Can not process a chain with negative memory {mem_budget}")
elif cost_table[mem_budget][lmin][lmax] == float("inf"):
raise ValueError(f"Can not process this chain from index {lmin} to {lmax} with memory {mem_budget}")
if budget <= 0:
raise ValueError(f"Can not process a chain with negative memory {budget}")
elif cost_table[budget][lhs][rhs] == float("inf"):
raise ValueError(f"Can not process this chain from index {lhs} to {rhs} with memory {budget}")
sequence = Sequence(Function("Persistent", lmax - lmin, mem_budget))
if lmin == lmax:
if lmin == chain.length:
sequence.insert(Loss())
sequence = Sequence()
if rhs == lhs:
if lhs == len(chain):
sequence += [Loss()]
else:
sequence.insert(ForwardEnable(lmin))
sequence.insert(Backward(lmin))
sequence += [ForwardEnable(lhs), Backward(lhs)]
return sequence
if back_ptr[mem_budget][lmin][lmax][0]:
sequence.insert(ForwardEnable(lmin))
sequence.insert_sequence(
self._backtrack(chain, lmin + 1, lmax, mem_budget - chain.xbar[lmin + 1], cost_table, back_ptr))
sequence.insert(Backward(lmin))
if back_ptr[budget][lhs][rhs][0]:
sequence += [
ForwardEnable(lhs),
CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
back_ptr),
Backward(lhs),
]
else:
j = back_ptr[mem_budget][lmin][lmax][1]
sequence.insert(ForwardCheck(lmin))
for k in range(lmin + 1, j):
sequence.insert(ForwardNograd(k))
sequence.insert_sequence(self._backtrack(chain, j, lmax, mem_budget - chain.xbar[j], cost_table, back_ptr))
sequence.insert_sequence(self._backtrack(chain, lmin, j - 1, mem_budget, cost_table, back_ptr))
best_leaf = back_ptr[budget][lhs][rhs][1]
sequence += [ForwardCheck(lhs)]
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
sequence += [
CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
back_ptr),
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
]
return sequence
@staticmethod
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
"""Annotate the nodes in the node_list with activation checkpoint from the sequence.
Args:
sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
node_list (List[List[Node]]): The list of nodes to annotate.
"""
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)]