[autochunk] support multi outputs chunk search (#2538)

Support multi-output chunk search. Previously we only supported single-output chunk search. The new strategy is more flexible and improves performance by a large margin. For transformers, it reduces memory by 40% compared with the previous search strategy.

1. rewrite the search strategy to support multi-output chunk search (a rough illustration follows this list)
2. fix many, many bugs
3. update tests
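
As a rough illustration of what a multi-output chunk region means (the function, shapes, and chunk size below are hypothetical and not taken from the autochunk code): a single loop over a chunked dimension can now fill several output tensors at once, instead of needing a separate searched region per output.

    import torch

    def chunked_region(x: torch.Tensor, chunk_size: int = 128):
        # Hypothetical multi-output chunk region: one pass over the chunked
        # dimension produces two outputs, so only one chunk of intermediate
        # activations is live at a time.
        out_a = torch.empty_like(x)
        out_b = torch.empty(x.shape[0], device=x.device)
        for start in range(0, x.shape[0], chunk_size):
            end = min(start + chunk_size, x.shape[0])
            piece = x[start:end]
            out_a[start:end] = torch.softmax(piece, dim=-1)  # first output
            out_b[start:end] = piece.sum(dim=-1)             # second output
        return out_a, out_b

Finding regions like this, where several outputs share one chunked dimension, is what the rewritten search strategy enables.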
oahzxl
2023-02-01 13:18:51 +08:00
committed by GitHub
parent f477a14f4a
commit 05671fcb42
14 changed files with 428 additions and 258 deletions


@@ -9,6 +9,59 @@ NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "siz
logger = get_dist_logger()
class NodeMgr(object):

    def __init__(self, gm) -> None:
        self._node_list = list(gm.graph.nodes)
        self._node_dict = {}
        self._set_node_dict()

    def _set_node_dict(self) -> None:
        """
        create a dict {node_name: node_idx}
        """
        self._node_dict.clear()
        for idx, node in enumerate(self._node_list):
            self._node_dict[node.name] = idx

    def find_node_idx(self, node: Node) -> int:
        """
        find a node's index
        """
        return self._node_dict[node.name]

    def find_node_idx_by_name(self, node_name: str) -> int:
        """
        find a node's index by its name
        """
        return self._node_dict[node_name]

    def get_node_by_idx(self, idx: int) -> Node:
        """
        get a node by index
        """
        return self._node_list[idx]

    def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
        """
        get a slice of nodes by index
        """
        return self._node_list[start:end]

    def get_node_list(self) -> List:
        """
        get the full node list
        """
        return self._node_list

    def update_node_list(self, node_list: List) -> None:
        """
        update the node list and rebuild the node dict
        """
        self._node_list = node_list
        self._set_node_dict()


def get_logger() -> Any:
    return logger
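
A hedged usage sketch of NodeMgr (MyModel and the symbolic_trace call are illustrative assumptions; NodeMgr only needs an object exposing gm.graph.nodes):

    import torch
    from torch.fx import symbolic_trace

    class MyModel(torch.nn.Module):  # hypothetical module for illustration
        def forward(self, x):
            return torch.relu(x) + 1

    gm = symbolic_trace(MyModel())   # GraphModule; gm.graph.nodes is the traced node list
    mgr = NodeMgr(gm)

    first = mgr.get_node_by_idx(0)                     # placeholder node for `x`
    assert mgr.find_node_idx(first) == 0
    assert mgr.find_node_idx_by_name(first.name) == 0
    print([n.name for n in mgr.get_node_slice_by_idx(0, 3)])

update_node_list() rebuilds the name-to-index dict, so index lookups stay O(1) after the node list changes.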
@@ -42,6 +95,8 @@ def is_non_compute_node(node: Node) -> bool:
    if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
        return True
    if "getitem" in node.name:
        # a getitem that produces a tensor (i.e. has a shape) is treated as a compute node
        if get_node_shape(node) is not None:
            return False
        node_args = flat_list(node.args[1:])
        for node_arg in node_args:
            if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
@@ -53,6 +108,8 @@ def is_non_compute_node(node: Node) -> bool:
def get_node_shape(node: Node) -> List:
    if get_node_name(node) == "split":
        # split stores a tuple of tensor metas; use the first element's shape
        return node.meta["tensor_meta"][0].shape
    if hasattr(node.meta["tensor_meta"], "shape"):
        return node.meta["tensor_meta"].shape
    return None
@@ -78,7 +135,7 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
    return is_non_compute_node_except_placeholder(node)


def find_idx_by_name(name: str, nodes_list: List) -> int:
def find_node_idx(name: str, nodes_list: List) -> int:
    for idx, node in enumerate(nodes_list):
        if node.name == name:
            return idx
@@ -162,3 +219,28 @@ def get_node_name(node: Node) -> str:
            else:
                break
    return node_name
def find_tensor_node(node_list: List[Node]) -> List[Node]:
    """
    find tensor nodes from a node list
    """
    out = []
    for node in node_list:
        if get_node_shape(node) is not None:
            out.append(node)
    return out


def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
    """
    find tensor and shape nodes from a node list
    """
    out = []
    for node in node_list:
        if get_node_shape(node) is not None:
            out.append(node)
        elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
                node.meta['fwd_out'][0], int):
            out.append(node)
    return out
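
A small, hedged example of how these two helpers behave (the stand-in nodes below are illustrative; in autochunk the nodes come from a traced, meta-annotated GraphModule):

    from types import SimpleNamespace

    # hypothetical stand-ins carrying the metadata the helpers read
    tensor_node = SimpleNamespace(
        name="linear_1",
        meta={"tensor_meta": SimpleNamespace(shape=(4, 16)), "fwd_out": []},
    )
    shape_node = SimpleNamespace(
        name="size_1",
        meta={"tensor_meta": None, "fwd_out": [4, 16]},  # a list of ints, e.g. a size() result
    )

    tensors = find_tensor_node([tensor_node, shape_node])     # keeps only the tensor-shaped node
    both = find_tensor_shape_node([tensor_node, shape_node])  # keeps the tensor node and the shape node
    assert tensors == [tensor_node] and both == [tensor_node, shape_node]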