[auto-chunk] support extramsa (#3) (#2504)

This commit is contained in:
oahzxl
2023-01-20 10:13:03 +08:00
committed by GitHub
parent 0f02b8c6e6
commit 72341e65f4
8 changed files with 283 additions and 54 deletions

View File

@@ -3,10 +3,12 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from torch.fx.node import Node
def flat_list(inputs):
def flat_list(inputs: Any) -> List:
"""
flat a list by recursion
"""
if not (isinstance(inputs, list) or isinstance(inputs, set) or isinstance(inputs, tuple)):
return [inputs]
res = []
for i in inputs:
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
@@ -16,7 +18,7 @@ def flat_list(inputs):
return res
def find_first_tensor_arg(node):
def find_first_tensor_arg(node: Node) -> Node:
"""
Find the first input tensor arg for a node
"""
@@ -26,7 +28,7 @@ def find_first_tensor_arg(node):
raise RuntimeError()
def is_non_compute_node(node):
def is_non_compute_node(node: Node) -> bool:
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
return True
if "getitem" in node.name:
@@ -34,16 +36,26 @@ def is_non_compute_node(node):
for node_arg in node_args:
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
return False
if "slice" in str(node_arg):
return False
return True
return False
def get_node_shape(node):
def get_node_shape(node: Node) -> List:
if hasattr(node.meta["tensor_meta"], "shape"):
return node.meta["tensor_meta"].shape
return None
def is_non_memory_node(node: Node) -> bool:
if "getitem" in node.name:
return True
if "output" in node.op:
return True
return is_non_compute_node(node)
def is_non_compute_node_except_placeholder(node):
if "placeholder" in node.op:
return False