code style

This commit is contained in:
oahzxl
2023-01-06 17:31:59 +08:00
parent a6cdbf9161
commit c3a2bf48b4
5 changed files with 46 additions and 36 deletions

View File

@@ -81,7 +81,9 @@ class TraceFlow(object):
input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_index.node_list[k])
inherit_dim = self._find_inherit_dim(
input_node, v, self.trace_index.node_list[k]
)
if inherit_dim:
input_dim_after_node[k] = inherit_dim
@@ -217,7 +219,9 @@ class TraceFlow(object):
for arg in arg_list:
if not (
start_idx
<= find_idx_by_name(arg.name, self.trace_index.node_list)
<= find_idx_by_name(
arg.name, self.trace_index.node_list
)
< end_idx
):
continue
@@ -255,7 +259,9 @@ class TraceFlow(object):
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None:
user_source = self.trace_index._find_source_trace_from_node(user)[chunk_dim]
user_source = self.trace_index._find_source_trace_from_node(
user
)[chunk_dim]
if input_node_idx in user_source:
input_dict[user_idx] = user_source[input_node_idx]
else: