[auto-chunk] support extramsa (#3) (#2504)

This commit is contained in:
oahzxl
2023-01-20 10:13:03 +08:00
committed by GitHub
parent 0f02b8c6e6
commit 72341e65f4
8 changed files with 283 additions and 54 deletions

View File

@@ -118,16 +118,34 @@ class TraceFlow(object):
def _assgin_single_node_flow(
self,
arg_node,
start_idx,
end_idx,
cur_node_dim,
cur_node_compute,
cur_node_source,
cur_node_fix_dim,
all_node_info,
next_node_list,
):
arg_node: Node,
start_idx: int,
end_idx: int,
cur_node_dim: int,
cur_node_compute: Dict,
cur_node_source: Dict,
cur_node_fix_dim: List,
all_node_info: Dict,
next_node_list: List,
) -> bool:
"""
Given the current node and one of its arg node,
this function finds out arg node's chunk dim and fix dim
Args:
arg_node (Node): input node
start_idx (int): chunk region start
end_idx (int): chunk region end
cur_node_dim (int): current node chunk dim
cur_node_compute (Dict): current node compute dict
cur_node_source (Dict): current node source dict
cur_node_fix_dim (List): current node fix dim
all_node_info (Dict): all node chunk info in the chunk region
next_node_list (List)
Returns:
bool: True if this node can be added to the flow, vice versa.
"""
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
# arg in chunk range or be inputs
if not (start_idx <= arg_idx < end_idx):
@@ -142,6 +160,9 @@ class TraceFlow(object):
arg_dim = None
else:
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
# chunk dim should be None if shape size is 1
if get_node_shape(arg_node)[arg_dim] == 1:
arg_dim = None
else:
arg_dim = None
@@ -184,7 +205,7 @@ class TraceFlow(object):
# get all valid args
arg_list = []
for arg in cur_node.args:
for arg in cur_node.all_input_nodes:
if type(arg) != type(cur_node):
continue
if is_non_compute_node(arg):