[fx] Modify solver linearize and add corresponding test (#1531)

* [fx] modify solver linearize and add test

* [fx] add torch11 test of linearize but skip it

* [fx] remove some unused imports
Author: Boyuan Yao
Date: 2022-09-02 10:24:41 +08:00
Committed by: GitHub
Parent: 7dc53237c3
Commit: 56159049e8
4 changed files with 181 additions and 99 deletions
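
The central change: linearize() now returns a plain list of node regions instead of an OrderedDict keyed by region index, with the input and output regions stripped. A minimal sketch of the shape change, using strings as stand-ins for torch.fx Nodes (the region contents are illustrative):

from collections import OrderedDict

# Old return shape: an OrderedDict with an input region at key 0 and an
# output region at the last key (illustrative stand-in values).
node_dict = OrderedDict({0: ["input"], 1: ["conv", "bn"], 2: ["relu"], 3: ["output"]})

# New return shape: a List[List[Node]] with the input/output regions stripped,
# so callers index regions positionally.
node_list = list(node_dict.values())[1:-1]
assert node_list == [["conv", "bn"], ["relu"]]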

View File

@@ -114,7 +114,7 @@ def _discretize(mem_unit, values):
     return [math.ceil(value / mem_unit) for value in values]
 
-def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: int) -> Chain:
+def _construct_chain(node_list: List[List[Node]], data: torch.Tensor, mem_unit: int) -> Chain:
     fwd_time = []
     bwd_time = []
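
For reference, _discretize (context line above) maps byte sizes onto integer slot counts; solver_rotor derives the slot size as mem_limit // mem_slots later in this diff. A small worked example with assumed numbers:

import math

mem_limit, mem_slots = 1024, 8
mem_unit = mem_limit // mem_slots                  # 128 bytes per slot
values = [100, 128, 300]                           # assumed activation sizes in bytes
print([math.ceil(v / mem_unit) for v in values])   # [1, 1, 3]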
@@ -122,22 +122,22 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: int) -> Chain:
     x_sizes = [data.numel() * data.element_size()]
     # currently we can't get the temp memory needed in fwd and bwd
-    tmp_fwd = [0] * len(node_dict)
-    tmp_bwd = [0] * (len(node_dict) + 1)
+    tmp_fwd = [0] * len(node_list)
+    tmp_bwd = [0] * (len(node_list) + 1)
 
-    for key in node_dict.keys():
+    for idx, node in enumerate(node_list):
         fwd_time.append(0)
         bwd_time.append(0)
         xbar_sizes.append(0)
-        x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
-                       torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
-        for node in node_dict[key]:
-            fwd_time[-1] += max(node.__flops__, 1)
+        x_sizes.append(node[-1].meta['tensor_meta'].numel *
+                       torch.tensor([], dtype=node[-1].meta['tensor_meta'].dtype).element_size())
+        for n in node:
+            fwd_time[-1] += max(n.__flops__, 1)
             # currently we haven't patched the backward flops count
-            bwd_time[-1] += max(node.__flops__ * 2, 2)
-            xbar_sizes[-1] += node.__activation__
+            bwd_time[-1] += max(n.__flops__ * 2, 2)
+            xbar_sizes[-1] += n.__activation__
         xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])
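
The per-region output size above is computed from the last node's tensor metadata: numel times the element size of the dtype, where element_size() is obtained from an empty tensor of that dtype. A standalone check with assumed values:

import torch

numel, dtype = 4096, torch.float16                 # assumed metadata values
size_bytes = numel * torch.tensor([], dtype=dtype).element_size()
print(size_bytes)                                  # 8192: 4096 halves at 2 bytes each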
@@ -151,14 +151,14 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: int) -> Chain:
     return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
 
-def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> GraphModule:
+def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) -> GraphModule:
     op_list = sequence.list_operations()
-    loss_op = [op for op in op_list if isinstance(op, Loss)][0]
+    loss_op = next(op for op in op_list if isinstance(op, Loss))
     op_list = op_list[:op_list.index(loss_op)]
     ckpt_idx = 0
     in_ckpt = False
     ckpt_region = []
-    for idx, op in enumerate(op_list, 1):
+    for idx, op in enumerate(op_list, 0):
         if in_ckpt:
             if isinstance(op, ForwardNograd):
                 ckpt_region.append(idx)
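
The loss lookup now uses next() on a generator rather than materializing the full filtered list and taking [0], so it stops at the first match; the enumerate start moves from 1 to 0 because the rewritten linearize() strips the input region, so op index idx addresses node_list[idx] directly. A minimal illustration of the next() idiom, with strings as stand-ins for rotor operations:

ops = ["fwd", "fwd", "loss", "bwd"]                # stand-ins for Sequence operations
loss_op = next(op for op in ops if op == "loss")
print(ops[:ops.index(loss_op)])                    # ['fwd', 'fwd']: ops before the loss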
@@ -166,16 +166,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> GraphModule:
             elif isinstance(op, ForwardEnable):
                 in_ckpt = False
                 for node_idx in ckpt_region:
-                    for node in node_dict[node_idx]:
-                        setattr(node, "activation_checkpoint", ckpt_idx)
+                    for n in node_list[node_idx]:
+                        setattr(n, "activation_checkpoint", ckpt_idx)
 
                 ckpt_idx += 1
                 ckpt_region = []
 
             elif isinstance(op, ForwardCheck):
                 for node_idx in ckpt_region:
-                    for node in node_dict[node_idx]:
-                        setattr(node, "activation_checkpoint", ckpt_idx)
+                    for n in node_list[node_idx]:
+                        setattr(n, "activation_checkpoint", ckpt_idx)
 
                 ckpt_idx += 1
                 ckpt_region = [idx]
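
Both branches above tag every fx Node of each region in ckpt_region with the same integer activation_checkpoint attribute, so downstream passes can recover the checkpoint groups. A toy re-enactment with dummy objects in place of fx Nodes:

class _DummyNode:                                  # stand-in for torch.fx.Node
    pass

node_list = [[_DummyNode()], [_DummyNode(), _DummyNode()]]
ckpt_region, ckpt_idx = [0, 1], 0
for node_idx in ckpt_region:
    for n in node_list[node_idx]:
        setattr(n, "activation_checkpoint", ckpt_idx)
assert all(n.activation_checkpoint == 0 for region in node_list for n in region)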
@@ -199,13 +199,13 @@ def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_slots
         ColoGraphModule: annotated ColoGraphModule with __sequence__ attribute
     """
-    node_dict = linearize(gm)
+    node_list = linearize(gm)
     mem_unit = mem_limit // mem_slots
     MetaInfoProp(gm).run(data)
-    chain: Chain = _construct_chain(node_dict, data, mem_unit)
+    chain: Chain = _construct_chain(node_list, data, mem_unit)
     opt_table = _compute_table(chain, mem_slots)
     sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
-    _annotate_from_sequence(sequence, node_dict)
+    _annotate_from_sequence(sequence, node_list)
     # set __sequence__ attribute to GraphModule
     setattr(gm, "__sequence__", sequence)

View File

@@ -1,89 +1,44 @@
-from typing import OrderedDict
-from torch.fx import GraphModule
-from collections import OrderedDict
-import pdb
-
-def linearize(gm: GraphModule) -> dict:
-    status_dict = {}
-    node_dict = OrderedDict()
-    node_idx = 0
-    for node in gm.graph.nodes:
-        last_dict_len = len(status_dict)
-
-        # remove node from users list in status_dict
-        for item in status_dict.values():
-            if node in item:
-                item.remove(node)
-
-        # pop node from status_dict if it is fully used
-        for key in list(status_dict):
-            if len(status_dict[key]) == 0:
-                status_dict.pop(key)
-
-        # first node in graph, it should be in n0-n1 type,
-        # where n0 contains only input op, i.e. placeholder
-        if last_dict_len == 0:
-            node_dict[node_idx] = [node]
-            status_dict[node.name] = list(node.users)
-            node_idx += 1
-            node_dict[node_idx] = []
-            continue
-
-        # boundary case
-        if len(status_dict) == 0:
-            # current node region end point = next node region start point
-            # i.e. n1-n2-n3-... type node, each node contains only one op
-            if last_dict_len == 1:
-                if len(node_dict[node_idx]) > 0:
-                    node_idx += 1
-                    node_dict[node_idx] = []
-                node_dict[node_idx].append(node)
-                status_dict[node.name] = list(node.users)
-                continue
-
-            # n1-n2-n3, if n1 has multiple ops, the last op in n1 will be
-            # the one who is able to clean all others in status_dict
-            # and as the last_dict_len > 1, there are multiple ops are used
-            # by this node, we view it as the end of one node and start a new node
-            else:
-                node_dict[node_idx].append(node)
-                status_dict[node.name] = list(node.users)
-                node_idx += 1
-                node_dict[node_idx] = []
-                continue
-
-        else:
-            # currently I will use bigger node structure
-            # if the following region is activated, the node will be smaller
-            #################################################
-            # if last_dict_len == 1:
-            #     if len(node_dict[node_idx]) > 0:
-            #         node_idx += 1
-            #         node_dict[node_idx] = [node]
-            #     status_dict[node.name] = list(node.users)
-            #
-            #     continue
-            #################################################
-
-            # in-node case, as the current node can not clean status_dict
-            # we view it as in-node status, the node will be appended to the
-            # current node_idx
-            node_dict[node_idx].append(node)
-            status_dict[node.name] = list(node.users)
-            continue
-
-    # If the output node use multiple nodes, there might be an
-    # empty node after the output node
-    if len(node_dict[node_idx]) == 0:
-        node_dict.pop[node_idx]
-        node_idx -= 1
-
-    # pop the last two nodes
-    node_dict.pop(0)
-    node_dict.pop(node_idx)
-
-    return node_dict
+from typing import List
+
+from torch.fx import GraphModule, Node
+
+
+def linearize(gm: GraphModule) -> List[List[Node]]:
+    """Linearize the graph.
+
+    Args:
+        gm (GraphModule): GraphModule derived by tracing
+
+    Returns:
+        List[List[Node]]: List of lists; each inner list of Nodes is one
+        'node' (region) of the linearized graph.
+    """
+
+    def _is_sink() -> bool:
+        """Check if we can free all dependencies.
+
+        Returns:
+            bool
+        """
+
+        return not sum([v for _, v in deps.items()])
+
+    deps = {}
+    linearized_nodes = []
+    region = []
+
+    for n in gm.graph.nodes:
+        for n_par in n._input_nodes:
+            deps[n_par] -= 1
+        region.append(n)
+
+        # if the node can free all dependencies in the graph,
+        # we can begin a new region
+        if _is_sink():
+            linearized_nodes.append(region)
+            region = []
+
+        deps[n] = len(n.users)
+
+    # Remove the input and output regions
+    linearized_nodes = linearized_nodes[1:-1]
+
+    return linearized_nodes
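
To see the new bookkeeping in action without torch.fx, here is a toy re-enactment: nodes arrive in topological order, deps counts unresolved users per produced value, and a region closes whenever _is_sink() reports that every outstanding dependency has been consumed (the graph below is illustrative):

graph = {                      # node -> its input nodes, in topological order
    "x": [],                   # placeholder (input)
    "a": ["x"],
    "b": ["a"],
    "c": ["a", "b"],
    "out": ["c"],              # output
}
users = {n: sum(n in ins for ins in graph.values()) for n in graph}

deps, linearized_nodes, region = {}, [], []
for n in graph:
    for n_par in graph[n]:
        deps[n_par] -= 1       # one dependency on n_par is now resolved
    region.append(n)
    if not sum(deps.values()): # _is_sink(): nothing left unconsumed
        linearized_nodes.append(region)
        region = []
    deps[n] = users[n]         # register n's own unconsumed users

print(linearized_nodes[1:-1])  # input/output stripped: [['a'], ['b', 'c']]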