mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-04-12 15:14:55 +00:00
Support multi-output chunk search. Previously we only supported single-output chunk search. The new search is more flexible and improves performance by a large margin: for transformers, it reduces memory by 40% compared with the previous search strategy. 1. rewrite the search strategy to support multi-output chunk search 2. fix many, many bugs 3. update tests
247 lines
7.0 KiB
Python
247 lines
7.0 KiB
Python
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
|
|
|
from torch.fx.node import Node
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
# ``node.op`` values that ``is_non_compute_node`` classifies as non-compute.
NON_COMPUTE_OP = [
    "placeholder",
    "get_attr",
    "output",
]

# Base node names (as produced by ``get_node_name``) that ``is_non_compute_node``
# classifies as non-compute.
NON_COMPUTE_NAME = [
    "getattr",
    "eq",
    "_assert_is_none",
    "_assert",
    "finfo",
    "size",
]
|
|
# Module-level logger created once at import time; ``get_logger()`` returns it.
logger = get_dist_logger()
|
|
|
|
|
|
class NodeMgr(object):
    """
    Hold an ordered copy of a graph's nodes together with a name -> index lookup.
    """

    def __init__(self, gm) -> None:
        # copy of ``gm.graph.nodes`` so the manager owns its own ordered list
        self._node_list = list(gm.graph.nodes)
        self._node_dict = {}
        self._set_node_dict()

    def _set_node_dict(self) -> None:
        """
        rebuild the {node_name: node_idx} lookup from ``self._node_list``;
        the existing dict object is emptied and refilled rather than replaced
        """
        lookup = self._node_dict
        lookup.clear()
        lookup.update((entry.name, position) for position, entry in enumerate(self._node_list))

    def find_node_idx(self, node: Node) -> int:
        """
        return the index of ``node``, resolved through its ``name``
        """
        key = node.name
        return self._node_dict[key]

    def find_node_idx_by_name(self, node_name: str) -> int:
        """
        return the index of the node called ``node_name`` (KeyError if unknown)
        """
        return self._node_dict[node_name]

    def get_node_by_idx(self, idx: int) -> Node:
        """
        return the node stored at position ``idx``
        """
        return self._node_list[idx]

    def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
        """
        return the nodes at positions ``start`` up to, but not including, ``end``
        """
        return self._node_list[slice(start, end)]

    def get_node_list(self) -> List:
        """
        return the managed node list itself (not a copy)
        """
        return self._node_list

    def update_node_list(self, node_list: List) -> None:
        """
        adopt ``node_list`` as the managed list and rebuild the name lookup
        """
        self._node_list = node_list
        self._set_node_dict()
|
|
|
|
|
|
def get_logger() -> Any:
    """
    Return the module-level ``logger`` created when this module was imported.
    """
    module_logger = logger
    return module_logger
|
|
|
|
|
|
def flat_list(inputs: Any) -> List:
    """
    Recursively flatten nested lists, sets and tuples into one flat list.

    A non-container input is wrapped as ``[inputs]``. Elements keep their
    iteration order; strings and other objects are treated as leaves.
    """
    containers = (list, set, tuple)
    if not isinstance(inputs, containers):
        return [inputs]
    flattened = []
    for item in inputs:
        # a leaf comes back as a one-element list, a container fully flattened
        flattened.extend(flat_list(item))
    return flattened
|
|
|
|
|
|
def find_first_tensor_arg(node: Node) -> Node:
    """
    Find the first positional argument of ``node`` that is itself a graph node
    (presumably the first input tensor).

    Args:
        node: the node whose ``args`` are scanned in order.

    Returns:
        The first element of ``node.args`` whose type is exactly ``type(node)``.

    Raises:
        RuntimeError: if no positional argument has that type.
    """
    node_type = type(node)
    for arg in node.args:
        # exact type comparison (not isinstance), as in the original check
        if type(arg) is node_type:
            return arg
    raise RuntimeError("cannot find a node-typed positional argument for node %s" % node.name)
|
|
|
|
|
|
def is_non_compute_node(node: Node) -> bool:
    """
    Decide whether ``node`` is classified as non-compute.

    A node is non-compute when its op is in ``NON_COMPUTE_OP`` or its base
    name (``get_node_name``) is in ``NON_COMPUTE_NAME``. A node whose name
    contains ``getitem`` and that has no tensor shape is also non-compute,
    unless one of its flattened index arguments (``args[1:]``) stringifies to
    ``None``/``Ellipsis`` or contains ``slice``.
    """
    if node.op in NON_COMPUTE_OP or get_node_name(node) in NON_COMPUTE_NAME:
        return True
    if "getitem" not in node.name:
        return False
    # a getitem whose output has a shape is counted as compute
    if get_node_shape(node) is not None:
        return False
    for index_arg in flat_list(node.args[1:]):
        index_text = str(index_arg)
        if index_text in ("None", "Ellipsis") or "slice" in index_text:
            return False
    return True
|
|
|
|
|
|
def get_node_shape(node: Node) -> List:
    """
    Return the shape recorded in ``node.meta["tensor_meta"]``, or None.

    For nodes whose base name is ``split``, the tensor meta is indexed and the
    shape of its first entry is returned. Otherwise the meta's ``shape``
    attribute is returned when present, and None when it is absent.
    """
    if get_node_name(node) == "split":
        return node.meta["tensor_meta"][0].shape
    tensor_meta = node.meta["tensor_meta"]
    return tensor_meta.shape if hasattr(tensor_meta, "shape") else None
|
|
|
|
|
|
def is_non_memory_node(node: Node) -> bool:
    """
    Classify ``node`` as non-memory: any node whose name contains ``getitem``,
    any node whose op contains ``output``, and anything ``is_non_compute_node``
    accepts.
    """
    return "getitem" in node.name or "output" in node.op or is_non_compute_node(node)
|
|
|
|
|
|
def is_non_compute_node_except_placeholder(node: Node) -> bool:
    """
    Same as ``is_non_compute_node``, except nodes whose op contains
    ``placeholder`` are never reported as non-compute.
    """
    return "placeholder" not in node.op and is_non_compute_node(node)
|
|
|
|
|
|
def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
    """
    Same as ``is_non_compute_node_except_placeholder``, except nodes whose op
    contains ``output`` are never reported as non-compute either.
    """
    return "output" not in node.op and is_non_compute_node_except_placeholder(node)
|
|
|
|
|
|
def find_node_idx(name: str, nodes_list: List) -> int:
    """
    Return the index of the first node in ``nodes_list`` whose ``name`` equals ``name``.

    Raises:
        RuntimeError: if no node in the list has that name.
    """
    position = next((idx for idx, candidate in enumerate(nodes_list) if candidate.name == name), None)
    if position is None:
        raise RuntimeError("name %s not found in node list" % name)
    return position
|
|
|
|
|
|
def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
    """
    Remove every placeholder node from each last-use collection, in place.

    Args:
        user_to_last_uses: mapping whose values are collections of nodes that
            support iteration and ``remove`` (e.g. lists). Each value object is
            mutated in place; the mapping itself is not modified.
    """
    for last_uses in user_to_last_uses.values():
        # iterate over a snapshot: removing from the collection being iterated
        # would skip the element that follows each removed one
        for n in list(last_uses):
            if n.op == "placeholder":
                last_uses.remove(n)
|
|
|
|
|
|
def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
    """
    Find every node outside ``nodes`` that some node in ``nodes`` consumes.

    Unlike ``find_chunk_compute_input_and_output_nodes``, non-compute inputs
    are kept and no output nodes are computed.

    Args:
        nodes: the nodes forming the region (chunk).

    Returns:
        The distinct input nodes (from each node's ``_input_nodes``) that are
        not themselves in ``nodes``, in first-seen order.
    """
    # sets give O(1) membership; the list keeps first-seen order
    region = set(nodes)
    seen = set()
    input_nodes = []
    for node in nodes:
        for input_node in node._input_nodes:
            if input_node not in region and input_node not in seen:
                seen.add(input_node)
                input_nodes.append(input_node)
    return input_nodes
|
|
|
|
|
|
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Tuple[List, List]:
    """
    Find the compute inputs and the outputs of a region of nodes.

    Args:
        nodes: the nodes forming the region (chunk).

    Returns:
        An ``(input_nodes, output_nodes)`` tuple, each in first-seen order:

        - ``input_nodes``: distinct nodes outside ``nodes`` that some node in
          ``nodes`` consumes, excluding those for which
          ``is_non_compute_node_except_placeholder`` is True.
        - ``output_nodes``: nodes *inside* ``nodes`` that have at least one user
          outside ``nodes`` for which
          ``is_non_compute_node_except_placeholder_output`` is False.
    """
    # sets give O(1) membership; the lists keep first-seen order
    region = set(nodes)

    # an outside node consumed by the region is treated as a chunk input
    input_nodes = []
    seen_inputs = set()
    for node in nodes:
        for input_node in node._input_nodes:
            if (input_node not in region and input_node not in seen_inputs
                    and not is_non_compute_node_except_placeholder(input_node)):
                seen_inputs.add(input_node)
                input_nodes.append(input_node)

    # a region node consumed by an outside user is treated as a chunk output
    output_nodes = []
    seen_outputs = set()
    for node in nodes:
        for user in node.users:
            if (user not in region and node not in seen_outputs
                    and not is_non_compute_node_except_placeholder_output(user)):
                seen_outputs.add(node)
                output_nodes.append(node)

    return input_nodes, output_nodes
|
|
|
|
|
|
def get_module_node_name(node: Node) -> str:
    """
    Return the lower-cased class name of the object that ``node.target`` names.

    ``node.target`` is treated as a dotted attribute path and resolved with
    ``getattr`` starting from ``node.graph.owning_module`` (presumably a
    ``call_module`` node targeting a submodule).
    """
    node_targets = node.target.split(".")
    module = node.graph.owning_module
    for attr in node_targets:
        module = getattr(module, attr)
    # __name__ is the bare class name; slicing str(cls) breaks for classes
    # whose repr has no dotted path, e.g. "<class 'int'>"
    return module.__class__.__name__.lower()
|
|
|
|
|
|
def get_node_name(node: Node) -> str:
    """
    get node name with a trailing ``_<digits>`` suffix removed

    e.g. ``linear_1`` -> ``linear`` and ``x_`` -> ``x``; a name whose text
    after the last underscore contains any non-ASCII-digit character
    (``layer_norm``, ``conv_2d``) is returned unchanged
    """
    full_name = node.name
    stem, underscore, suffix = full_name.rpartition("_")
    if underscore and all(ch in "0123456789" for ch in suffix):
        return stem
    return full_name
|
|
|
|
|
|
def find_tensor_node(node_list: List[Node]) -> List[Node]:
    """
    Return, in their original order, the nodes of ``node_list`` for which
    ``get_node_shape`` reports a shape.
    """
    return [candidate for candidate in node_list if get_node_shape(candidate) is not None]
|
|
|
|
|
|
def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
    """
    Return, in order, the nodes of ``node_list`` that yield a tensor or a shape.

    A node qualifies when ``get_node_shape`` returns a shape, or when
    ``node.meta["fwd_out"]`` is a non-empty list whose first element is an
    int (presumably a size/shape value — TODO confirm). Nodes without a
    ``fwd_out`` entry, or with a non-list one, are skipped.
    """
    out = []
    for node in node_list:
        if get_node_shape(node) is not None:
            out.append(node)
            continue
        fwd_out = node.meta.get("fwd_out")
        # check the type before len(): fwd_out is not guaranteed to be sized
        if isinstance(fwd_out, list) and len(fwd_out) > 0 and isinstance(fwd_out[0], int):
            out.append(node)
    return out
|