[fx/tuning] tune performance on rotor with meta info. (#1599)

This commit is contained in:
Super Daniel
2022-09-15 14:46:36 +08:00
committed by GitHub
parent a7cda6f57d
commit cd5cf2bcc9
7 changed files with 96 additions and 107 deletions

View File

@@ -1,8 +1,7 @@
from typing import List, Tuple
import torch
from torch.fx import GraphModule, Node
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import parameter_size
from colossalai.fx.profiler import activation_size, parameter_size
import math
from .linearize import linearize
from .utils import *
@@ -31,7 +30,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
# Build table
opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
## Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for m in range(mmax + 1):
@@ -115,43 +114,6 @@ def _discretize(mem_unit, values):
return [math.ceil(value / mem_unit) for value in values]
def _compute_size(obj: torch.Tensor) -> int:
return obj.numel() * obj.element_size()
def _compute_output_size(node: List[Node]) -> int:
"""Compute the output size of a node
Args:
node (List[Node]): node, list of torch.fx.Node
Returns:
int: output size
"""
return node[-1].meta['tensor_meta'].numel * torch.tensor([],
dtype=node[-1].meta['tensor_meta'].dtype).element_size()
def _get_inplace(node: Node) -> bool:
"""Get the inplace argument from torch.fx.Node
Args:
node (Node): torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
is_inplace = False
if node.op == "call_function":
is_inplace = node.kwargs.get("inplace", False)
elif node.op == "call_module":
is_inplace = getattr(node.graph.owning_module.get_submodule(node.target), "inplace", False)
return is_inplace
def _fwd_xbar(node: List[Node]) -> int:
"""Get the forward xbar of a node
@@ -221,46 +183,33 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
for k, v in deps.items():
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
return deps_size
bwd_mem_tmp = 0
deps = {}
# add all the users for last node into deps,
# as those nodes' gradient out will be stored in memory
for child in node[-1].users:
deps[child] = 1
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
deps[n] = len(n.all_input_nodes)
for child in n.users:
if child in deps:
deps[child] -= 1
for key in list(deps.keys()):
if deps[key] == 0:
del deps[key]
if deps[child] <= 0:
deps[child] = float('-inf') # free
return bwd_mem_tmp
def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain:
fwd_time = []
bwd_time = []
if isinstance(data, torch.Tensor):
xbar_sizes = [_compute_size(data)]
x_sizes = [_compute_size(data)]
elif isinstance(data, list) or isinstance(data, tuple):
xbar_sizes = [sum([_compute_size(obj) for obj in data])]
x_sizes = [sum([_compute_size(obj) for obj in data])]
elif isinstance(data, dict):
xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
x_sizes = [sum([_compute_size(obj) for obj in data.values()])]
xbar_sizes = [activation_size(input)]
x_sizes = [activation_size(input)]
# currently we can't get the temp memory needed in fwd
tmp_fwd = [0] * len(node_list)
tmp_bwd = []
@@ -268,14 +217,10 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
for idx, node in enumerate(node_list):
fwd_time.append(_fwd_time(node))
bwd_time.append(_bwd_time(node))
x_sizes.append(_compute_output_size(node))
x_sizes.append(node[-1].meta['fwd_mem_out'])
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
tmp_bwd.append(_get_bwd_mem_tmp(node))
# if a node with only one inplace op, we need to let x_bar = 0
if len(node) == 1 and _get_inplace(node[0]):
xbar_sizes[-1] = 0
bwd_time.append(0)
# currently we view loss backward temp as zero
@@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit: int,
mem_slots: int = 500,
cnode: List[str] = None,
eps: float = 0.02) -> ColoGraphModule:
eps: float = 0.0) -> ColoGraphModule:
"""solver that automatically find activation checkpoint in rotor's manner
Args:
@@ -390,7 +335,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit (int): memory budget in Byte.
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
cnode (List[Node], optional): common node list for linearize. Defaults to None.
eps (float): epsilon for memory decay. Defaults to 0.02
eps (float): epsilon for memory decay. Defaults to 0.0
Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute

View File

@@ -1,5 +1,6 @@
from typing import List, Any
from torch.fx import GraphModule, Node
from colossalai.fx.profiler import is_inplace
# Common nodes are type of nodes that could be seen as attributes and remain
# unchanged throughout the whole model, it will be used several times by
@@ -41,6 +42,9 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
Returns:
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.
Remarks:
We merge the inplace ops into the previous node.
"""
def _is_sink() -> bool:
@@ -50,7 +54,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
bool
"""
return not sum([v for _, v in deps.items()])
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
# make sure that item in cnode is valid
if cnode: