[fx/tuning] tune performance on rotor with meta info. (#1599)
Commit to the ColossalAI repository (mirror of https://github.com/hpcaitech/ColossalAI.git).
@@ -1,8 +1,7 @@
 from typing import List, Tuple
 import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import parameter_size
+from colossalai.fx.profiler import activation_size, parameter_size
 import math
 from .linearize import linearize
 from .utils import *
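The import swap above trades the file-local size helpers (removed later in this diff) for the profiler's activation_size. A hedged sketch of what such a helper plausibly computes, mirroring the isinstance branches deleted from _construct_chain further down; the real implementation lives in colossalai.fx.profiler and may differ in detail:

# Sketch only: sums tensor bytes over arbitrarily nested containers,
# like the removed _compute_size branches did case by case.
import torch

def activation_size_sketch(obj) -> int:
    if isinstance(obj, torch.Tensor):
        return obj.numel() * obj.element_size()
    if isinstance(obj, dict):
        return sum(activation_size_sketch(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return sum(activation_size_sketch(v) for v in obj)
    return 0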
@@ -31,7 +30,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     # Build table
     opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
     what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
-    ## Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
+    # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation

     # Initialize borders of the tables for lmax-lmin = 0
     for m in range(mmax + 1):
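For orientation: opt and what implement the dynamic program of the rotor checkpointing algorithm this solver adapts. A hedged illustration of their shape and indexing (semantics assumed from the rotor algorithm, not taken from this file):

# opt[m][i][l]: best run time for chain stages i..l under discretized budget m.
# what[m][i][l]: the decision achieving it. The inner dicts are keyed by the
# segment end l, as the comment above notes ("indices go from i to l").
chain_length, mmax = 3, 4
opt = [[{} for _ in range(chain_length + 1)] for _ in range(mmax + 1)]
what = [[{} for _ in range(chain_length + 1)] for _ in range(mmax + 1)]
opt[0][0][3] = float('inf')    # e.g. budget 0 cannot run stages 0..3
print(opt[0][0])               # {3: inf}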
@@ -115,43 +114,6 @@ def _discretize(mem_unit, values):
     return [math.ceil(value / mem_unit) for value in values]


-def _compute_size(obj: torch.Tensor) -> int:
-    return obj.numel() * obj.element_size()
-
-
-def _compute_output_size(node: List[Node]) -> int:
-    """Compute the output size of a node
-
-    Args:
-        node (List[Node]): node, list of torch.fx.Node
-
-    Returns:
-        int: output size
-    """
-
-    return node[-1].meta['tensor_meta'].numel * torch.tensor([],
-                                                             dtype=node[-1].meta['tensor_meta'].dtype).element_size()
-
-
-def _get_inplace(node: Node) -> bool:
-    """Get the inplace argument from torch.fx.Node
-
-    Args:
-        node (Node): torch.fx.Node
-
-    Returns:
-        bool: indicates whether this op is inplace
-    """
-
-    is_inplace = False
-    if node.op == "call_function":
-        is_inplace = node.kwargs.get("inplace", False)
-    elif node.op == "call_module":
-        is_inplace = getattr(node.graph.owning_module.get_submodule(node.target), "inplace", False)
-
-    return is_inplace
-
-
 def _fwd_xbar(node: List[Node]) -> int:
     """Get the forward xbar of a node

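Context for the hunk above: _discretize maps byte sizes onto integer multiples of a memory unit so the DP tables stay small. A quick worked example with assumed numbers:

import math

def _discretize(mem_unit, values):
    # as in the source: round byte counts up to whole memory slots
    return [math.ceil(value / mem_unit) for value in values]

# hypothetical budget: 16 GiB split into 500 slots (the solver's default)
mem_unit = (16 * 1024**3) // 500
print(_discretize(mem_unit, [mem_unit, mem_unit + 1, 5 * mem_unit]))    # [1, 2, 5]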
@@ -221,46 +183,33 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
         for k, v in deps.items():
             if v > 0:
                 deps_size += k.meta['bwd_mem_out']
+            if v == float('-inf'):
+                deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']

         return deps_size

     bwd_mem_tmp = 0
     deps = {}

-    # add all the users for last node into deps,
-    # as those nodes' gradient out will be stored in memory
-    for child in node[-1].users:
-        deps[child] = 1
     for n in reversed(node):
+        deps[n] = len(n.all_input_nodes)
         bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])

-        deps[n] = len(n.all_input_nodes)
         for child in n.users:
             if child in deps:
                 deps[child] -= 1
+                if deps[child] <= 0:
+                    deps[child] = float('-inf')    # free

-        for key in list(deps.keys()):
-            if deps[key] == 0:
-                del deps[key]

     return bwd_mem_tmp


-def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
+def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain:

     fwd_time = []
     bwd_time = []

-    if isinstance(data, torch.Tensor):
-        xbar_sizes = [_compute_size(data)]
-        x_sizes = [_compute_size(data)]
-    elif isinstance(data, list) or isinstance(data, tuple):
-        xbar_sizes = [sum([_compute_size(obj) for obj in data])]
-        x_sizes = [sum([_compute_size(obj) for obj in data])]
-    elif isinstance(data, dict):
-        xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
-        x_sizes = [sum([_compute_size(obj) for obj in data.values()])]
-
+    xbar_sizes = [activation_size(input)]
+    x_sizes = [activation_size(input)]
     # currently we can't get the temp memory needed in fwd
     tmp_fwd = [0] * len(node_list)
     tmp_bwd = []
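The reworked _get_bwd_mem_tmp is a small liveness analysis: deps counts how many uses of each node's output remain during the backward sweep, and a count that reaches zero flips to float('-inf') so _get_deps_size can subtract that node's forward memory as freed, rather than merely dropping the key as the old code did. A self-contained sketch of the counting idea (names and numbers are illustrative, not from the file):

# Illustrative reference counting over a toy dependency map; mirrors the
# deps bookkeeping above without torch.fx.
uses = {'a': 2, 'b': 1}          # remaining consumers per produced tensor
sizes = {'a': 100, 'b': 40}

def live_bytes():
    return sum(sizes[k] for k, v in uses.items() if v > 0)

uses['b'] -= 1
if uses['b'] <= 0:
    uses['b'] = float('-inf')    # mark freed instead of deleting the key
print(live_bytes())              # 100: only 'a' is still alive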
@@ -268,14 +217,10 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
     for idx, node in enumerate(node_list):
         fwd_time.append(_fwd_time(node))
         bwd_time.append(_bwd_time(node))
-        x_sizes.append(_compute_output_size(node))
+        x_sizes.append(node[-1].meta['fwd_mem_out'])
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
         tmp_bwd.append(_get_bwd_mem_tmp(node))

-        # if a node with only one inplace op, we need to let x_bar = 0
-        if len(node) == 1 and _get_inplace(node[0]):
-            xbar_sizes[-1] = 0
-
     bwd_time.append(0)

     # currently we view loss backward temp as zero
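With linearize now merging inplace ops into their producer's group (see the second file below), the special x_bar = 0 case is obsolete, and output sizes come straight from profiler metadata instead of being recomputed from tensor_meta. A hedged sketch of the meta entries this pass reads; the keys are taken from this diff, and the profiler is assumed to have populated them beforehand:

from torch.fx import Node

def dump_ckpt_meta(n: Node) -> None:
    # keys observed in this diff; values are byte counts filled in by the profiler
    for key in ('fwd_mem_out', 'fwd_mem_tmp', 'bwd_mem_out', 'bwd_mem_tmp'):
        print(key, n.meta.get(key, 0))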
@@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
                  mem_limit: int,
                  mem_slots: int = 500,
                  cnode: List[str] = None,
-                 eps: float = 0.02) -> ColoGraphModule:
+                 eps: float = 0.0) -> ColoGraphModule:
     """solver that automatically find activation checkpoint in rotor's manner

     Args:
@@ -390,7 +335,7 @@ def solver_rotor(gm: ColoGraphModule,
         mem_limit (int): memory budget in Byte.
         mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
         cnode (List[Node], optional): common node list for linearize. Defaults to None.
-        eps (float): epsilon for memory decay. Defaults to 0.02
+        eps (float): epsilon for memory decay. Defaults to 0.0

     Returns:
         ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
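To situate the signature change, a hedged end-to-end usage sketch. The module paths, the MetaInfoProp step, and the positional data argument are assumptions based on the repository layout of this era; nothing in this diff shows them directly:

import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import solver_rotor

model = tm.resnet18()
data = torch.rand(128, 3, 224, 224, device='meta')
graph = ColoTracer().trace(model, meta_args={'x': data})
gm = ColoGraphModule(model, graph, model.__class__.__name__)
MetaInfoProp(gm).run(data)                            # populate the node.meta used above
gm = solver_rotor(gm, data, mem_limit=8 * 1024**3)    # 8 GiB budget; eps now defaults to 0.0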
The commit also touches the linearize pass:

@@ -1,5 +1,6 @@
 from typing import List, Any
 from torch.fx import GraphModule, Node
+from colossalai.fx.profiler import is_inplace

 # Common nodes are type of nodes that could be seen as attributes and remain
 # unchanged throughout the whole model, it will be used several times by
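is_inplace replaces the local _get_inplace removed from the solver file above, now shared via the profiler. A hedged sketch of what it plausibly covers; the call_method branch is my assumption, and the real helper in colossalai.fx.profiler may differ:

from torch.fx import Node

def is_inplace_sketch(n: Node) -> bool:
    if n.op == 'call_function':
        return n.kwargs.get('inplace', False)
    if n.op == 'call_module':
        return getattr(n.graph.owning_module.get_submodule(n.target), 'inplace', False)
    if n.op == 'call_method':
        return isinstance(n.target, str) and n.target.endswith('_')    # e.g. add_, relu_
    return False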
@@ -41,6 +42,9 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
     Returns:
         List[List[Node]]: List of list, each inside list of Node presents
         the actual 'node' in linearized manner.
+
+    Remarks:
+        We merge the inplace ops into the previous node.
     """

     def _is_sink() -> bool:
@@ -50,7 +54,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
         bool
         """

-        return not sum([v for _, v in deps.items()])
+        return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))

     # make sure that item in cnode is valid
     if cnode:
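The stricter _is_sink means a group only closes when no pending user of the current node is inplace, which is what realizes the "merge inplace ops into the previous node" remark: a checkpoint boundary placed before an op that overwrites its input would corrupt recomputation. A tiny illustration with a hypothetical module and standard torch.fx:

import torch.nn as nn
from torch.fx import symbolic_trace

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.act = nn.ReLU(inplace=True)    # overwrites fc's output buffer

    def forward(self, x):
        return self.act(self.fc(x))

gm = symbolic_trace(Tiny())
# Under the new rule, fc is not a sink while its inplace user is pending,
# so fc and act land in one linearized group and no boundary splits them.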