mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[autochunk] support transformer (#2526)
This commit is contained in:
@@ -1,13 +1,15 @@
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
NON_COMPUTE_OP = ["placeholder", "get_attr", "output"]
|
||||
NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"]
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def get_logger() -> Any:
    """Return the module-level distributed logger shared by the autochunk utils."""
    return logger
|
||||
|
||||
|
||||
@@ -37,7 +39,7 @@ def find_first_tensor_arg(node: Node) -> Node:
|
||||
|
||||
|
||||
def is_non_compute_node(node: Node) -> bool:
|
||||
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
|
||||
if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
|
||||
return True
|
||||
if "getitem" in node.name:
|
||||
node_args = flat_list(node.args[1:])
|
||||
@@ -64,33 +66,33 @@ def is_non_memory_node(node: Node) -> bool:
|
||||
return is_non_compute_node(node)
|
||||
|
||||
|
||||
def is_non_compute_node_except_placeholder(node: Node) -> bool:
    """Classify *node* as non-compute, except that placeholders never qualify.

    Args:
        node: fx graph node to classify.

    Returns:
        ``False`` for placeholder nodes; otherwise whatever
        :func:`is_non_compute_node` reports.
    """
    # placeholder ops are explicitly excluded from the non-compute set here
    return False if "placeholder" in node.op else is_non_compute_node(node)
|
||||
|
||||
|
||||
def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
    """Classify *node* as non-compute, except that output nodes never qualify.

    Args:
        node: fx graph node to classify.

    Returns:
        ``False`` for output nodes; otherwise whatever
        :func:`is_non_compute_node_except_placeholder` reports.
    """
    # output ops are explicitly excluded from the non-compute set here
    return False if "output" in node.op else is_non_compute_node_except_placeholder(node)
|
||||
|
||||
|
||||
def find_idx_by_name(name: str, nodes_list: List) -> int:
    """Return the index of the first node in ``nodes_list`` named *name*.

    Args:
        name: node name to search for.
        nodes_list: sequence of fx ``Node``-like objects exposing ``.name``.

    Returns:
        Zero-based index of the matching node.

    Raises:
        RuntimeError: if no node with the given name exists.
    """
    for idx, node in enumerate(nodes_list):
        if node.name == name:
            return idx
    # f-string produces the exact same message as the old %-format version
    raise RuntimeError(f"name {name} not found in node list")
|
||||
|
||||
|
||||
def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
    """Remove placeholder (free-variable) nodes from every last-use list, in place.

    Args:
        user_to_last_uses: maps a user node to the list of nodes whose last
            use occurs at that user; the lists are mutated in place.
    """
    for last_uses in user_to_last_uses.values():
        # Iterate over a snapshot: the original code removed from the very list
        # it was iterating, which skips the element that follows each removed
        # placeholder (e.g. the second of two consecutive placeholders survived).
        for n in list(last_uses):
            if n.op == "placeholder":
                last_uses.remove(n)
|
||||
|
||||
|
||||
def find_chunk_all_input_nodes(nodes: List[Node]):
|
||||
def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
|
||||
"""
|
||||
Find non-compute input and output node names.
|
||||
input nodes are nodes used in the list
|
||||
@@ -104,7 +106,7 @@ def find_chunk_all_input_nodes(nodes: List[Node]):
|
||||
return input_nodes
|
||||
|
||||
|
||||
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
||||
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, List]:
|
||||
"""
|
||||
Find non-compute input and output node names.
|
||||
input nodes are nodes used in the list
|
||||
@@ -130,3 +132,33 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
||||
output_nodes.append(node)
|
||||
|
||||
return input_nodes, output_nodes
|
||||
|
||||
|
||||
def get_module_node_name(node: Node) -> str:
    """Return the lowercased class name of the module that *node* targets.

    The node's dotted ``target`` path is walked attribute-by-attribute from the
    graph's owning module, and the resolved submodule's class name (e.g.
    ``"linear"`` for ``torch.nn.Linear``) is returned in lowercase.
    """
    resolved = node.graph.owning_module
    for attr in node.target.split("."):
        resolved = getattr(resolved, attr)
    # str(cls) looks like "<class 'pkg.Mod'>"; keep the last dotted piece
    # and drop the trailing "'>" before lowercasing
    class_name = str(resolved.__class__).split(".")[-1][:-2]
    return class_name.lower()
|
||||
|
||||
|
||||
def get_node_name(node: Node) -> str:
    """Return *node*'s name with any trailing ``_<digits>`` suffix removed.

    fx deduplicates node names by appending ``_1``, ``_2``, ... ; this strips
    that numeric suffix (and its underscore) so e.g. ``"add_1"`` -> ``"add"``,
    while names like ``"layer_norm"`` or ``"conv2d"`` are returned unchanged.
    """
    raw_name = node.name
    # drop trailing digits; if that exposes an underscore, the digits were a
    # dedup suffix and the underscore goes too — otherwise keep the full name
    stripped = raw_name.rstrip("0123456789")
    if stripped.endswith("_"):
        return stripped[:-1]
    return raw_name
|
||||
|
Reference in New Issue
Block a user