[autochunk] support multi outputs chunk search (#2538)

Support multi-output chunk search. Previously, only single-output chunk search was supported. The new strategy is more flexible and improves performance by a large margin. For transformers, it reduces memory usage by 40% compared with the previous search strategy.

1. rewrite the search strategy to support multi-output chunk search
2. fix many, many bugs
3. update tests
This commit is contained in:
oahzxl
2023-02-01 13:18:51 +08:00
committed by GitHub
parent f477a14f4a
commit 05671fcb42
14 changed files with 428 additions and 258 deletions

View File

@@ -3,14 +3,7 @@ from typing import Dict, List, Tuple
from torch.fx.node import Node
from .utils import (
find_first_tensor_arg,
find_idx_by_name,
flat_list,
get_module_node_name,
get_node_name,
get_node_shape,
)
from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape
class TraceIndice(object):
@@ -35,8 +28,8 @@ class TraceIndice(object):
node_list (List)
"""
def __init__(self, node_list: List[Node]) -> None:
self.node_list = node_list
def __init__(self, node_mgr: NodeMgr) -> None:
self.node_mgr = node_mgr
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
self.indice_count = -1
@@ -45,7 +38,7 @@ class TraceIndice(object):
def _init_indice_trace_list(self) -> List:
indice_trace_list = []
for n in self.node_list:
for n in self.node_mgr.get_node_list():
if get_node_shape(n) != None:
cur_trace = {
"indice": [None for _ in range(len(get_node_shape(n)))],
@@ -99,7 +92,7 @@ class TraceIndice(object):
node_from_trace_source = self._find_source_trace_from_node(node_from)
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_to_trace_source = self._find_source_trace_from_node(node_to)
node_from_idx = find_idx_by_name(node_from.name, self.node_list)
node_from_idx = self.node_mgr.find_node_idx(node_from)
if init:
node_to_trace_source[node_to_dim] = {}
# add dim to cur new source
@@ -200,7 +193,7 @@ class TraceIndice(object):
idx (list): idx of the node
compute (list): computed idx of the node.
"""
node_idx = find_idx_by_name(node.name, self.node_list)
node_idx = self.node_mgr.find_node_idx(node)
node_dict = self.indice_trace_list[node_idx]
return node_dict
@@ -214,7 +207,7 @@ class TraceIndice(object):
idx (list): idx of the node
compute (list): computed idx of the node.
"""
node_idx = find_idx_by_name(node.name, self.node_list)
node_idx = self.node_mgr.find_node_idx(node)
node_dict = self.indice_trace_list[node_idx]
return node_dict["source"]
@@ -227,7 +220,7 @@ class TraceIndice(object):
Returns:
idx (list): idx of the node
"""
node_idx = find_idx_by_name(node.name, self.node_list)
node_idx = self.node_mgr.find_node_idx(node)
return self.indice_trace_list[node_idx]["indice"]
def _find_compute_trace_from_node(self, node: Node) -> List:
@@ -239,7 +232,7 @@ class TraceIndice(object):
Returns:
compute (list): computed idx of the node.
"""
node_idx = find_idx_by_name(node.name, self.node_list)
node_idx = self.node_mgr.find_node_idx(node)
return self.indice_trace_list[node_idx]["compute"]
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:
@@ -454,8 +447,6 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
for _ in range(len(get_node_shape(node.args[0]))):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)
dim_idx = node.kwargs["dim"]
self._del_dim(node_idx, dim_idx)
@@ -702,21 +693,20 @@ class TraceIndice(object):
if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
and view_dict["dim_from"] == dim_to):
# inheirt indice from current node
for dim_to_i in dim_to:
for dim_from_i in dim_from:
self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False)
if len_diff == 1:
if origin_shape[dim_from[0]] == 1:
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
elif origin_shape[dim_from[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
elif len_diff == -1:
if target_shape[dim_to[0]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
elif target_shape[dim_to[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# inherid indice from input node of last view
for dim_to_i in dim_to:
self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
# inherit computation
compute_log = self._find_compute_trace_from_node(origin_node)
for i in dim_from:
if origin_trace[i] in compute_log:
for j in dim_to:
self._mark_computation(node, node_idx, [j])
break
# log view, not used now
view_dict = {
"idx_from": [origin_trace[i] for i in dim_from],
@@ -742,7 +732,7 @@ class TraceIndice(object):
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
active_nodes = set(flat_list(active_nodes))
active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes]
for i in range(trace_range[0], trace_range[1] + 1):
trace = self.indice_trace_list[i]
# clear compute
@@ -758,7 +748,7 @@ class TraceIndice(object):
dim_source.pop(k)
def trace_indice(self) -> None:
for idx, node in enumerate(self.node_list):
for idx, node in enumerate(self.node_mgr.get_node_list()):
node_name = get_node_name(node)
if node.op == "placeholder":
self._assign_all_indice(node, idx)