mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[autochunk] support autochunk on evoformer (#2497)
This commit is contained in:
@@ -3,14 +3,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
from torch.fx.node import Node
|
||||
|
||||
|
||||
def unflat_list(inputs):
|
||||
def flat_list(inputs):
|
||||
"""
|
||||
unflat a list by recursion
|
||||
flat a list by recursion
|
||||
"""
|
||||
res = []
|
||||
for i in inputs:
|
||||
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
|
||||
res.extend(unflat_list(i))
|
||||
res.extend(flat_list(i))
|
||||
else:
|
||||
res.append(i)
|
||||
return res
|
||||
@@ -27,8 +27,13 @@ def find_first_tensor_arg(node):
|
||||
|
||||
|
||||
def is_non_compute_node(node):
|
||||
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
|
||||
i in node.name for i in ["getitem", "getattr"]):
|
||||
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
|
||||
return True
|
||||
if "getitem" in node.name:
|
||||
node_args = flat_list(node.args[1:])
|
||||
for node_arg in node_args:
|
||||
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -40,15 +45,15 @@ def get_node_shape(node):
|
||||
|
||||
|
||||
def is_non_compute_node_except_placeholder(node):
|
||||
if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]):
|
||||
return True
|
||||
return False
|
||||
if "placeholder" in node.op:
|
||||
return False
|
||||
return is_non_compute_node(node)
|
||||
|
||||
|
||||
def is_non_compute_node_except_placeholder_output(node):
|
||||
if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]):
|
||||
return True
|
||||
return False
|
||||
if "output" in node.op:
|
||||
return False
|
||||
return is_non_compute_node_except_placeholder(node)
|
||||
|
||||
|
||||
def find_idx_by_name(name, nodes_list):
|
||||
|
Reference in New Issue
Block a user