mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
@@ -118,16 +118,34 @@ class TraceFlow(object):
|
||||
|
||||
def _assgin_single_node_flow(
|
||||
self,
|
||||
arg_node,
|
||||
start_idx,
|
||||
end_idx,
|
||||
cur_node_dim,
|
||||
cur_node_compute,
|
||||
cur_node_source,
|
||||
cur_node_fix_dim,
|
||||
all_node_info,
|
||||
next_node_list,
|
||||
):
|
||||
arg_node: Node,
|
||||
start_idx: int,
|
||||
end_idx: int,
|
||||
cur_node_dim: int,
|
||||
cur_node_compute: Dict,
|
||||
cur_node_source: Dict,
|
||||
cur_node_fix_dim: List,
|
||||
all_node_info: Dict,
|
||||
next_node_list: List,
|
||||
) -> bool:
|
||||
"""
|
||||
Given the current node and one of its arg node,
|
||||
this function finds out arg node's chunk dim and fix dim
|
||||
|
||||
Args:
|
||||
arg_node (Node): input node
|
||||
start_idx (int): chunk region start
|
||||
end_idx (int): chunk region end
|
||||
cur_node_dim (int): current node chunk dim
|
||||
cur_node_compute (Dict): current node compute dict
|
||||
cur_node_source (Dict): current node source dict
|
||||
cur_node_fix_dim (List): current node fix dim
|
||||
all_node_info (Dict): all node chunk info in the chunk region
|
||||
next_node_list (List)
|
||||
|
||||
Returns:
|
||||
bool: True if this node can be added to the flow, vice versa.
|
||||
"""
|
||||
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
|
||||
# arg in chunk range or be inputs
|
||||
if not (start_idx <= arg_idx < end_idx):
|
||||
@@ -142,6 +160,9 @@ class TraceFlow(object):
|
||||
arg_dim = None
|
||||
else:
|
||||
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
|
||||
# chunk dim should be None if shape size is 1
|
||||
if get_node_shape(arg_node)[arg_dim] == 1:
|
||||
arg_dim = None
|
||||
else:
|
||||
arg_dim = None
|
||||
|
||||
@@ -184,7 +205,7 @@ class TraceFlow(object):
|
||||
|
||||
# get all valid args
|
||||
arg_list = []
|
||||
for arg in cur_node.args:
|
||||
for arg in cur_node.all_input_nodes:
|
||||
if type(arg) != type(cur_node):
|
||||
continue
|
||||
if is_non_compute_node(arg):
|
||||
|
Reference in New Issue
Block a user