[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.

* [autoparallel] add solver rotor.

* [autoparallel] add ckpt solvers.

* [autoparallel] modify codegen.

* [fx] fix annotation in test.

* [fx] remove check.

* [autoparallel] polish docstring.

* [fx] refactor MetaTensor.
Super Daniel
2022-11-01 10:43:15 +08:00
committed by GitHub
parent 2b859502d5
commit 1e88811c7a
16 changed files with 1025 additions and 119 deletions

View File: __init__.py

@@ -0,0 +1,3 @@
from .ckpt_solver_base import CheckpointSolverBase
from .ckpt_solver_chen import CheckpointSolverChen
from .ckpt_solver_rotor import CheckpointSolverRotor
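Taken together, the package exposes two interchangeable solvers behind a common base class. A minimal usage sketch (the import path is an assumption based on the commit title; `gm` is a hypothetical `torch.fx.GraphModule` whose nodes were already annotated by `MetaInfoProp`):

# hypothetical usage; import path assumed from the commit title
from colossalai.auto_parallel.checkpoint import CheckpointSolverChen, CheckpointSolverRotor

solver = CheckpointSolverChen(gm.graph)    # greedy, memory-only
# solver = CheckpointSolverRotor(gm.graph, memory_budget=2 * 1024**3)    # DP, time-aware
gm.graph = solver.solve()
gm.recompile()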

View File: ckpt_solver_base.py

@@ -0,0 +1,167 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, List
from torch.fx import Graph, Node
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
from colossalai.fx.profiler.memory_utils import is_inplace
__all__ = ['CheckpointSolverBase']
def _copy_output(src: Graph, dst: Graph):
"""Copy the output node from src to dst"""
for n_src, n_dst in zip(src.nodes, dst.nodes):
if n_src.op == 'output':
n_dst.meta = n_src.meta
class CheckpointSolverBase(ABC):
def __init__(
self,
graph: Graph,
memory_budget: float = -1.0,
parameter_size: float = 0,
requires_linearize: bool = False,
cnode: List[str] = None,
):
"""CheckpointSolver class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for
target computing graph.
Existing Solvers:
Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
Rotor solver: https://hal.inria.fr/hal-02352969 (CheckpointSolverRotor)
Args:
graph (Graph): The computing graph to be optimized.
memory_budget (float): Memory constraint for the solution.
parameter_size (float): The total parameter size of the model. Use `parameter_size(model)` to estimate it.
requires_linearize (bool): Whether the graph needs to be linearized.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
Warnings:
`MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
"""
# super-dainiu: this graph is a temporary graph which can refer to
# the owning module, but we will return another deepcopy of it after
# the solver is executed.
self.graph = deepcopy(graph)
self.graph.owning_module = graph.owning_module
_copy_output(graph, self.graph)
self.graph.set_codegen(ActivationCheckpointCodeGen())
# check if `MetaInfoProp` is done
if any(len(node.meta) == 0 for node in self.graph.nodes):
raise RuntimeError(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
self.memory_budget = memory_budget
self.parameter_size = parameter_size
self.cnode = cnode
self.requires_linearize = requires_linearize
if self.requires_linearize:
self.node_list = self._linearize_graph()
else:
self.node_list = self.get_node_list()
@abstractmethod
def solve(self):
"""Solve the checkpointing problem and return the solution.
"""
pass
def get_node_list(self):
"""Get the node list.
"""
return [[node] for node in self.graph.nodes]
def _linearize_graph(self) -> List[List[Node]]:
"""Linearize the graph.
Returns:
List[List[Node]]: List of lists; each inner list of Nodes represents
one 'node' of the linearized graph.
Remarks:
In-place ops are merged into the preceding node.
"""
# Common nodes are nodes that behave like attributes and remain unchanged
# throughout the whole model; they are reused by different blocks of the
# model, which makes it hard to linearize the graph when we encounter them.
# We let users annotate some of the inputs as common nodes (e.g. an attention
# mask), and the ops below can also be treated as common nodes. With this
# common-node propagation we can find the "real" common nodes (e.g. the
# actual attention mask used in BERT and GPT). The rule is simple: a node
# whose parents are all common nodes, or whose op belongs to the operations
# below, is itself treated as a newly born common node.
# List of target name that could be seen as common node
common_ops = ["getattr", "getitem", "size"]
def _is_cop(target: Any) -> bool:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if isinstance(target, str):
return target in common_ops
else:
return target.__name__ in common_ops
def _is_sink() -> bool:
"""Check if we can free all dependencies
Returns:
bool
"""
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
# make sure that every item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
else:
self.cnode = []
deps = {}
node_list = []
region = []
for n in self.graph.nodes:
if n.op != "placeholder" and n.op != "output":
for n_par in n.all_input_nodes:
if n_par.op != "placeholder" and n_par.name not in self.cnode:
deps[n_par] -= 1
region.append(n)
# if the node frees all of its dependencies in the graph,
# we can start a new region
if _is_sink():
node_list.append(region)
region = []
# propagate common node attr if possible
if all(node.name in self.cnode for node in n.all_input_nodes) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
return node_list
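The constructor above rejects graphs whose nodes carry empty meta, so `MetaInfoProp` must run before any solver is built. A preparation sketch, assuming ColossalAI's `ColoTracer` and the `MetaInfoProp` pass (exact import paths and call signature are assumptions; `MyModel` and the input shape are hypothetical):

import torch
from torch.fx import GraphModule
from colossalai.fx import ColoTracer    # assumed import path
from colossalai.fx.passes.meta_info_prop import MetaInfoProp    # assumed import path

model = MyModel()
graph = ColoTracer().trace(model)
gm = GraphModule(model, graph, model.__class__.__name__)
MetaInfoProp(gm).run(torch.empty(2, 3, 224, 224, device='meta'))    # fills node.meta
# now gm.graph satisfies the meta check in CheckpointSolverBase.__init__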

View File: ckpt_solver_chen.py

@@ -0,0 +1,87 @@
import math
from copy import deepcopy
from typing import List, Tuple
from torch.fx import Graph, Node
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
from .ckpt_solver_base import CheckpointSolverBase
__all__ = ['CheckpointSolverChen']
class CheckpointSolverChen(CheckpointSolverBase):
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
"""
This is a simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
Note that this algorithm targets memory optimization only, using the techniques in Appendix A.
Usage:
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverChen`:
>>> solver = CheckpointSolverChen(gm.graph)
>>> chen_graph = solver.solve()
>>> gm.graph = chen_graph # set the graph to a new graph
Args:
graph (Graph): The computing graph to be optimized.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
num_grids (int, optional): Number of grids to search for b. Defaults to 6.
"""
super().__init__(graph, 0, 0, True, cnode)
self.num_grids = num_grids
def solve(self) -> Graph:
"""Solve the checkpointing problem using Algorithm 3.
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
ckpt = self.grid_search()
for i, seg in enumerate(ckpt):
for idx in range(*seg):
nodes = self.node_list[idx]
for n in nodes:
if n.op in checkpointable_op:
n.meta['activation_checkpoint'] = i
return deepcopy(self.graph)
def run_chen_greedy(self, b: int = 0) -> Tuple[List, int]:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
"""
ckpt_intv = []
temp = 0
x = 0
y = 0
prev_idx = 2
for idx, nodes in enumerate(self.node_list):
for n in nodes:
n: Node
temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
y = max(y, temp)
if temp > b and idx > prev_idx:
x += calculate_fwd_in(nodes[0])
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
return ckpt_intv, math.floor(math.sqrt(x * y))
def grid_search(self) -> List:
"""
Search for a ckpt strategy with b = 0 first to estimate b = √(xy), then re-run the allocation algorithm.
Grid search for ckpt_opt over [√2/2 b, √2 b] with num_grids steps, as in Appendix A.
"""
_, b_approx = self.run_chen_greedy(0)
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
b_opt = math.inf
for b in range(b_min, b_max, (b_max - b_min) // self.num_grids):
ckpt_intv, b_approx = self.run_chen_greedy(b)
if b_approx < b_opt:
b_opt = b_approx
ckpt_opt = ckpt_intv
return ckpt_opt
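As a hypothetical worked example of the search above: if the b = 0 pass accumulates x = 4 GB of checkpointed segment inputs and peaks at y = 9 GB of temporaries, run_chen_greedy returns b_approx = ⌊√(4 · 9)⌋ = 6 GB; grid_search then scans b over roughly [6/√2, 6·√2] ≈ [4.2, 8.5] GB in (b_max - b_min) // num_grids increments, keeping the interval set whose recomputed √(xy) estimate is smallest.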

View File: ckpt_solver_rotor.py

@@ -0,0 +1,387 @@
from copy import deepcopy
from typing import Dict, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
activation_size,
calculate_bwd_time,
calculate_fwd_out,
calculate_fwd_time,
calculate_fwd_tmp,
)
from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
__all__ = ['CheckpointSolverRotor']
class CheckpointSolverRotor(CheckpointSolverBase):
def __init__(self,
graph: Graph,
memory_budget: float = -1,
parameter_size: float = 0,
cnode: List[str] = None,
memory_slots: int = 500):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.
Usage:
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverRotor`:
>>> solver = CheckpointSolverRotor(gm.graph, memory_budget=memory_budget, parameter_size=parameter_size)
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
>>> gm.graph = rotor_graph # set the graph to a new graph
Args:
graph (Graph): The computing graph to be optimized.
memory_budget (float, optional): Memory constraint for the solution, in bytes.
parameter_size (float, optional): The total parameter size of the model, in bytes. Use `parameter_size(model)` to estimate it.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
"""
super().__init__(graph, memory_budget, parameter_size, True, cnode)
self.memory_slots = memory_slots
# construct chain
unit = self.memory_budget // self.memory_slots
self.chain = self._construct_chain(self.graph, self.node_list)
self.chain.discretize_all(unit)
self.cost_table = None
self.back_ptr = None
self.sequence = None
def solve(self, force_python: bool = False) -> Graph:
"""Solve the checkpointing problem using rotor algorithm.
Args:
force_python (bool, optional): Use Python version of solver, else use C version. Defaults to False.
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
chain = self.chain
# compute cost table
if force_python:
self.cost_table, self.back_ptr = self._compute_table(chain, self.memory_slots)
else:
self.cost_table, self.back_ptr = self._compute_table_c(chain, self.memory_slots)
# backtrack
try:
self.sequence = self._backtrack(chain, 0, chain.length, self.memory_slots, self.cost_table, self.back_ptr)
self._annotate_from_sequence(self.sequence, self.node_list)
except RuntimeError as e:
# use the logger to announce that the solver failed
logger = get_dist_logger()
logger.warning(f'Checkpoint solver failed: {e}')
return deepcopy(self.graph)
def print_chain(self):
print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
for idx in range(len(self.node_list) - 1):
print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
self.chain.btmp[idx])
print(f'Chain = {self.chain}')
def print_sequence(self):
print(f'Sequence = {self.sequence}')
@classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
input_tensors = cls._extract_input(graph)
fwd_time, bwd_time, ftmp, btmp = list(), list(), list(), list()
xbar, x = [activation_size(input_tensors)], [activation_size(input_tensors)]
for idx, node in enumerate(node_list):
node_info = cls._extract_node_info(node)
fwd_time.append(node_info[0])
bwd_time.append(node_info[1])
x.append(node_info[2])
xbar.append(node_info[3])
ftmp.append(node_info[4])
btmp.append(node_info[5])
# currently we view loss backward temp as zero
bwd_time.append(0)
btmp.append(0)
return Chain(fwd_time, bwd_time, x, xbar, ftmp, btmp)
@classmethod
def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]:
"""Extract node info from a list of nodes"""
xbar = 0
fwd_time = 0
bwd_time = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
# minimum flop count is required
fwd_time += max(calculate_fwd_time(n), 1.0)
bwd_time += max(calculate_bwd_time(n), 1.0)
x = calculate_fwd_out(node[-1])
xbar = max(x, xbar)
ftmp = cls._extract_ftmp(node)
btmp = cls._extract_btmp(node)
return fwd_time, bwd_time, x, xbar, ftmp, btmp
@staticmethod
def _extract_input(graph: Graph) -> Tuple[Tensor, ...]:
"""Extract input tensors from a Graph"""
input_tensors = []
for node in graph.nodes:
if node.op == 'placeholder':
input_tensors.append(node.meta['fwd_out'])
return input_tensors
@staticmethod
def _extract_ftmp(node: List[Node]) -> int:
"""Extract ftmp from a list of nodes"""
n = node[-1]
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
@staticmethod
def _extract_btmp(node: List[Node]) -> int:
"""Extract btmp from a list of nodes"""
def _extract_deps_size():
deps_size = 0
for k, v in deps.items():
k: Node
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
btmp = 0
deps = {}
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
for child in n.users:
if child in deps:
deps[child] -= 1
if deps[child] <= 0:
deps[child] = float('-inf') # free
return btmp
@staticmethod
def _compute_table(chain: Chain, mem_slots: int) -> Tuple:
"""Compute the table using dynamic programming. Returns the cost table and the backtracking pointer.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
mem_slots (int): Number of slots for discretizing memory budget.
Returns:
cost_table (List[List[Dict[int, Tuple]]]): cost_table[m][lmin][lmax] with lmin = 0...chain.length
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
back_ptr (List[List[Dict[int, Tuple]]]): back_ptr[m][lmin][lmax] is (True,) if the optimal choice
is a chain checkpoint, or (False, j) if the optimal choice is a leaf checkpoint
of length j
"""
ftime = chain.ftime + [0.0]
btime = chain.btime
x = chain.x + [0]
xbar = chain.xbar + [0]
ftmp = chain.ftmp + [0]
btmp = chain.btmp + [0]
# Build table
cost_table = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)]
back_ptr = [[{} for _ in range(chain.length + 1)] for _ in range(mem_slots + 1)]
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for m in range(mem_slots + 1):
for i in range(chain.length + 1):
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
if m >= limit: # Equation (1)
cost_table[m][i][i] = ftime[i] + btime[i]
else:
cost_table[m][i][i] = float("inf")
# Compute everything
for m in range(mem_slots + 1):
for d in range(1, chain.length + 1):
for i in range(chain.length + 1 - d):
idx = i + d
mmin = x[idx + 1] + x[i + 1] + ftmp[i]
if idx > i + 1:
mmin = max(mmin, x[idx + 1] + max(x[j] + x[j + 1] + ftmp[j] for j in range(i + 1, idx)))
if m < mmin:
cost_table[m][i][idx] = float("inf")
else:
leaf_checkpoints = [(j,
sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
for j in range(i + 1, idx + 1)
if m >= x[j]]
if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else:
best_leaf = None
if m >= xbar[i + 1]:
chain_checkpoint = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][idx]
else:
chain_checkpoint = float("inf")
if best_leaf and best_leaf[1] <= chain_checkpoint:
cost_table[m][i][idx] = best_leaf[1]
back_ptr[m][i][idx] = (False, best_leaf[0])
else:
cost_table[m][i][idx] = chain_checkpoint
back_ptr[m][i][idx] = (True,)
return cost_table, back_ptr
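Restated in the code's own notation, the nested loops above implement the rotor recurrence: for a budget m and an interval [i, l],

leaf(j) = sum(ftime[i:j]) + cost_table[m - x[j]][j][l] + cost_table[m][i][j - 1], for i < j <= l with m >= x[j]
chain = cost_table[m][i][i] + cost_table[m - xbar[i + 1]][i + 1][l], valid when m >= xbar[i + 1]
cost_table[m][i][l] = min(min over j of leaf(j), chain)

with the base case cost_table[m][i][i] = ftime[i] + btime[i] whenever m clears the Equation (1) limit, and infinity otherwise.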
@staticmethod
def _compute_table_c(chain: Chain, mem_slots: int) -> Tuple:
raise NotImplementedError("C implementation not available yet")
def _backtrack(self, chain: Chain, lmin: int, lmax: int, mem_budget: int, cost_table: List[List[Dict[int, Tuple]]],
back_ptr: List[List[Dict[int, Tuple]]]) -> Sequence:
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args:
chain (Chain): A basic linearized structure for solving the dynamic programming problem.
lmin (int): The left index of the interval to backtrack.
lmax (int): The right index of the interval to backtrack.
mem_budget (int): The memory budget for processing this interval.
cost_table (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions
back_ptr (List[List[Dict[int, Tuple]]]): See _compute_table() for definitions
Raises:
ValueError: Can not process the chain.
Returns:
sequence (Sequence): The sequence of executing nodes with checkpoints.
"""
if mem_budget <= 0:
raise ValueError(f"Can not process a chain with non-positive memory budget {mem_budget}")
elif cost_table[mem_budget][lmin][lmax] == float("inf"):
raise ValueError(f"Can not process this chain from index {lmin} to {lmax} with memory {mem_budget}")
sequence = Sequence(Function("Persistent", lmax - lmin, mem_budget))
if lmin == lmax:
if lmin == chain.length:
sequence.insert(Loss())
else:
sequence.insert(ForwardEnable(lmin))
sequence.insert(Backward(lmin))
return sequence
if back_ptr[mem_budget][lmin][lmax][0]:
sequence.insert(ForwardEnable(lmin))
sequence.insert_sequence(
self._backtrack(chain, lmin + 1, lmax, mem_budget - chain.xbar[lmin + 1], cost_table, back_ptr))
sequence.insert(Backward(lmin))
else:
j = back_ptr[mem_budget][lmin][lmax][1]
sequence.insert(ForwardCheck(lmin))
for k in range(lmin + 1, j):
sequence.insert(ForwardNograd(k))
sequence.insert_sequence(self._backtrack(chain, j, lmax, mem_budget - chain.xbar[j], cost_table, back_ptr))
sequence.insert_sequence(self._backtrack(chain, lmin, j - 1, mem_budget, cost_table, back_ptr))
return sequence
@staticmethod
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
# forward annotation
for idx, op in enumerate(fwd_list, 0):
if in_ckpt:
if isinstance(op, ForwardNograd):
ckpt_region.append(idx)
elif isinstance(op, ForwardEnable):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = [idx]
else:
if isinstance(op, ForwardCheck):
in_ckpt = True
ckpt_region.append(idx)
# annotate the backward if there is any nested activation checkpoint
in_recompute = False
for op in bwd_list:
if in_recompute:
if isinstance(op, ForwardNograd):
ckpt_region.append(op.index)
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
in_recompute = False
else:
if not isinstance(op, Backward):
in_recompute = True
ckpt_idx = 0
ckpt_region = []
if isinstance(op, ForwardCheck):
ckpt_region.append(op.index)
# postprocess, make sure every activation checkpoint label in the
# same activation checkpoint region (level = 0) has the same length
op_list = []
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
nested_length = max(
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
for idx in range(start_idx, end_idx + 1):
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
len(op_list[idx].meta['activation_checkpoint']))
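The net effect of `_annotate_from_sequence` is a list-valued `activation_checkpoint` label on every checkpointed node, which `ActivationCheckpointCodeGen` later consumes; the postprocessing pads labels with None so all nodes in one level-0 region carry labels of equal length. A hypothetical illustration for two nodes in outer region 0, where only the second also belongs to nested region 1:

node_a.meta['activation_checkpoint']    # [0, None]
node_b.meta['activation_checkpoint']    # [0, 1]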

View File: operation.py

@@ -0,0 +1,241 @@
import math
from abc import ABC
from typing import List
from torch.utils._pytree import tree_map
class Chain:
def __init__(self,
ftime: List[float],
btime: List[float],
x: List[int],
xbar: List[int],
ftmp: List[int],
btmp: List[int],
check_consistency: bool = True):
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
See paper https://hal.inria.fr/hal-02352969 for details.
Args:
ftime (List[float]): The forward time of each node.
btime (List[float]): The backward time of each node.
x (List[int]): The forward memory of each node (if save_output). Same as `a` in the paper.
xbar (List[int]): The forward memory of each node (if save_all). Same as `a_bar` in the paper.
ftmp (List[int]): The temporary forward memory of each node.
btmp (List[int]): The temporary backward memory of each node, can be used to control memory budget.
check_consistency (bool, optional): Check the lengths consistency for the `Chain`. Defaults to True.
"""
self.ftime = ftime
self.btime = btime
self.x = x
self.xbar = xbar
self.ftmp = ftmp
self.btmp = btmp
self.length = len(ftime)
if check_consistency and not self.check_lengths():
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
return ((len(self.ftime) == self.length) and (len(self.btime) == self.length + 1)
and (len(self.x) == self.length + 1) and (len(self.ftmp) == self.length)
and (len(self.btmp) == self.length + 1) and (len(self.xbar) == self.length + 1))
def __repr__(self):
chain_list = []
for i in range(self.length):
chain_list.append((self.ftime[i], self.btime[i], self.x[i], self.xbar[i], self.ftmp[i], self.btmp[i]))
i = self.length
chain_list.append((None, self.btime[i], self.x[i], self.xbar[i], None, self.btmp[i]))
return chain_list.__repr__()
def discretize_all(self, unit: int):
"""Discretize the chain into a list of chains according to unit size."""
discretizer = lambda val: math.ceil(val / unit)
self.x = tree_map(discretizer, self.x)
self.xbar = tree_map(discretizer, self.xbar)
self.ftmp = tree_map(discretizer, self.ftmp)
self.btmp = tree_map(discretizer, self.btmp)
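A quick worked example of the discretization (numbers hypothetical): with memory_budget = 1 GiB and memory_slots = 500, CheckpointSolverRotor computes unit = 1 GiB // 500 ≈ 2.1 MB, so a 5 MB activation occupies ceil(5 MB / unit) = 3 slots; rescaling x, xbar, ftmp and btmp this way bounds the DP table at memory_slots + 1 budget rows regardless of the absolute budget.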
class Operation(ABC):
name = "Op"
def __repr__(self) -> str:
return f"{self.name}_{self.index}"
def shift(self, value):
if type(self.index) is tuple:
self.index = tuple(x + value for x in self.index)
else:
self.index += value
class Forward(Operation):
name = "F"
def __init__(self, index):
self.index = index
def cost(self, chain: Chain):
if chain is not None:
return chain.ftime[self.index]
else:
return 1
class ForwardEnable(Forward):
name = "Fe"
class ForwardNograd(Forward):
name = "Fn"
class ForwardCheck(Forward):
name = "CF"
class Forwards(Operation):
def __init__(self, start, end):
self.index = (start, end)
def __repr__(self):
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
def cost(self, chain: Chain):
if chain is not None:
return sum(chain.ftime[self.index[0]:self.index[1] + 1])
else:
return (self.index[1] - self.index[0] + 1)
def isForward(op):
return type(op) is Forward or type(op) is Forwards
class Backward(Operation):
name = "B"
def __init__(self, index):
self.index = index
def cost(self, chain: Chain):
if chain is not None:
return chain.btime[self.index]
else:
return 1
class Loss(Operation):
def __init__(self):
pass
def __repr__(self):
return "L"
def cost(self, chain):
return 0
class MemoryAccess(Operation):
name = "MA"
def __init__(self, index):
self.index = index
def cost(self, chain: Chain):
return 0
class WriteMemory(MemoryAccess):
name = "WM"
class ReadMemory(MemoryAccess):
name = "RM"
class DiscardMemory(MemoryAccess):
name = "DM"
class Function:
def __init__(self, name, *args):
self.name = name
self.args = args
self.str_args = ','.join(str(v) for v in self.args)
def __repr__(self):
return "{n}({args})".format(n=self.name, args=self.str_args)
class Sequence:
def __init__(self, function):
self.sequence = []    # list of Operation and Sequence objects
self.function = function    # describes the function (name and parameters)
def __repr__(self):
return repr(self.list_operations())
def list_operations(self):
op_list = []
for x in self.sequence:
if isinstance(x, Operation):
op_list.append(x)
else:
assert isinstance(x, Sequence)
op_list += x.list_operations()
return op_list
def insert(self, operation):
self.sequence.append(operation)
def remove(self, operation_index):
del self.sequence[operation_index]
def insert_sequence(self, sequence):
self.sequence.append(sequence)
def shift(self, value):
for x in self.sequence:
x.shift(value)
return self
def remove_useless_write(self):
if self.sequence:
if isinstance(self.sequence[0], WriteMemory):
self.remove(0)
return self
def get_makespan(self, chain):
return sum(op.cost(chain) for op in self.list_operations())
def without_suffix(self):
ops = self.list_operations()
end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
try:
last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
except ValueError:
last_idx = -1
if last_idx == end_of_first_phase - 1:
return (self, None)
# Some assumptions here about the sequence: the first phase finishes with
# Forward_L and the second starts with B_L; this should be fine in practice.
chain_length = ops[end_of_first_phase - 1].index
start_of_fwd_enable_chain = ops[last_idx + 1].index
result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
for i in range(last_idx + 1):
result.insert(ops[i])
result.insert(Loss())
for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
position = end_of_first_phase + 1 + (chain_length - i)
assert type(ops[position]) is Backward
assert ops[position].index == i
for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
result.insert(ops[i])
return (result, start_of_fwd_enable_chain)
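A tiny sketch of how the operation and sequence classes above compose (indices hypothetical):

seq = Sequence(Function("Persistent", 2, 10))
seq.insert(ForwardCheck(0))      # repr: CF_0
seq.insert(ForwardNograd(1))     # repr: Fn_1
seq.insert(Loss())               # repr: L
seq.insert(Backward(0))          # repr: B_0
print(seq)                       # [CF_0, Fn_1, L, B_0]
print(seq.get_makespan(None))    # unit costs when chain is None: 1 + 1 + 0 + 1 = 3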