mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[autochunk] support multi outputs chunk search (#2538)
Support multi outputs chunk search. Previously we only support single output chunk search. It is more flexible and improve performance by a large margin. For transformer, we reduce memory by 40% than previous search strategy. 1. rewrite search strategy to support multi outputs chunk search 2. fix many, many bugs 3. update tests
This commit is contained in:
@@ -9,6 +9,59 @@ NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "siz
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
class NodeMgr(object):
|
||||
|
||||
def __init__(self, gm) -> None:
|
||||
self._node_list = list(gm.graph.nodes)
|
||||
self._node_dict = {}
|
||||
self._set_node_dict()
|
||||
|
||||
def _set_node_dict(self) -> None:
|
||||
"""
|
||||
create a dict {node_name: node_idx}
|
||||
"""
|
||||
self._node_dict.clear()
|
||||
for idx, node in enumerate(self._node_list):
|
||||
self._node_dict[node.name] = idx
|
||||
|
||||
def find_node_idx(self, node: Node) -> int:
|
||||
"""
|
||||
find node's index
|
||||
"""
|
||||
return self._node_dict[node.name]
|
||||
|
||||
def find_node_idx_by_name(self, node_name: str) -> int:
|
||||
"""
|
||||
find node's index
|
||||
"""
|
||||
return self._node_dict[node_name]
|
||||
|
||||
def get_node_by_idx(self, idx: int) -> Node:
|
||||
"""
|
||||
get a node by index
|
||||
"""
|
||||
return self._node_list[idx]
|
||||
|
||||
def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
|
||||
"""
|
||||
get a slice of node by index
|
||||
"""
|
||||
return self._node_list[start:end]
|
||||
|
||||
def get_node_list(self) -> List:
|
||||
"""
|
||||
get full node list
|
||||
"""
|
||||
return self._node_list
|
||||
|
||||
def update_node_list(self, node_list: List) -> None:
|
||||
"""
|
||||
update node list, reset node dict
|
||||
"""
|
||||
self._node_list = node_list
|
||||
self._set_node_dict()
|
||||
|
||||
|
||||
def get_logger() -> Any:
|
||||
return logger
|
||||
|
||||
@@ -42,6 +95,8 @@ def is_non_compute_node(node: Node) -> bool:
|
||||
if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
|
||||
return True
|
||||
if "getitem" in node.name:
|
||||
if get_node_shape(node) is not None:
|
||||
return False
|
||||
node_args = flat_list(node.args[1:])
|
||||
for node_arg in node_args:
|
||||
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
|
||||
@@ -53,6 +108,8 @@ def is_non_compute_node(node: Node) -> bool:
|
||||
|
||||
|
||||
def get_node_shape(node: Node) -> List:
|
||||
if get_node_name(node) == "split":
|
||||
return node.meta["tensor_meta"][0].shape
|
||||
if hasattr(node.meta["tensor_meta"], "shape"):
|
||||
return node.meta["tensor_meta"].shape
|
||||
return None
|
||||
@@ -78,7 +135,7 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
|
||||
return is_non_compute_node_except_placeholder(node)
|
||||
|
||||
|
||||
def find_idx_by_name(name: str, nodes_list: List) -> int:
|
||||
def find_node_idx(name: str, nodes_list: List) -> int:
|
||||
for idx, node in enumerate(nodes_list):
|
||||
if node.name == name:
|
||||
return idx
|
||||
@@ -162,3 +219,28 @@ def get_node_name(node: Node) -> str:
|
||||
else:
|
||||
break
|
||||
return node_name
|
||||
|
||||
|
||||
def find_tensor_node(node_list: List[Node]) -> List[Node]:
|
||||
"""
|
||||
find tensor nodes from a node list
|
||||
"""
|
||||
out = []
|
||||
for node in node_list:
|
||||
if get_node_shape(node) is not None:
|
||||
out.append(node)
|
||||
return out
|
||||
|
||||
|
||||
def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
|
||||
"""
|
||||
find tensor and shape nodes from a node list
|
||||
"""
|
||||
out = []
|
||||
for node in node_list:
|
||||
if get_node_shape(node) is not None:
|
||||
out.append(node)
|
||||
elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
|
||||
node.meta['fwd_out'][0], int):
|
||||
out.append(node)
|
||||
return out
|
||||
|
Reference in New Issue
Block a user