mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
@@ -3,10 +3,12 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
from torch.fx.node import Node
|
||||
|
||||
|
||||
def flat_list(inputs):
|
||||
def flat_list(inputs: Any) -> List:
|
||||
"""
|
||||
flat a list by recursion
|
||||
"""
|
||||
if not (isinstance(inputs, list) or isinstance(inputs, set) or isinstance(inputs, tuple)):
|
||||
return [inputs]
|
||||
res = []
|
||||
for i in inputs:
|
||||
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
|
||||
@@ -16,7 +18,7 @@ def flat_list(inputs):
|
||||
return res
|
||||
|
||||
|
||||
def find_first_tensor_arg(node):
|
||||
def find_first_tensor_arg(node: Node) -> Node:
|
||||
"""
|
||||
Find the first input tensor arg for a node
|
||||
"""
|
||||
@@ -26,7 +28,7 @@ def find_first_tensor_arg(node):
|
||||
raise RuntimeError()
|
||||
|
||||
|
||||
def is_non_compute_node(node):
|
||||
def is_non_compute_node(node: Node) -> bool:
|
||||
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
|
||||
return True
|
||||
if "getitem" in node.name:
|
||||
@@ -34,16 +36,26 @@ def is_non_compute_node(node):
|
||||
for node_arg in node_args:
|
||||
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
|
||||
return False
|
||||
if "slice" in str(node_arg):
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_node_shape(node):
|
||||
def get_node_shape(node: Node) -> List:
|
||||
if hasattr(node.meta["tensor_meta"], "shape"):
|
||||
return node.meta["tensor_meta"].shape
|
||||
return None
|
||||
|
||||
|
||||
def is_non_memory_node(node: Node) -> bool:
|
||||
if "getitem" in node.name:
|
||||
return True
|
||||
if "output" in node.op:
|
||||
return True
|
||||
return is_non_compute_node(node)
|
||||
|
||||
|
||||
def is_non_compute_node_except_placeholder(node):
|
||||
if "placeholder" in node.op:
|
||||
return False
|
||||
|
Reference in New Issue
Block a user